Commit dfe29f5e authored by Tri Dao

[Gen] Don't use ft_attention, use flash_attn_with_kvcache instead

parent 3250ff3d
@@ -6,3 +6,9 @@ FasterTransformer v5.2.1 for benchmarking purpose.
 ```sh
 cd csrc/ft_attention && pip install .
 ```
+As of 2023-09-17, this extension is no longer used in the FlashAttention repo.
+FlashAttention has now implemented
+[`flash_attn_with_kvcache`](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attention_interface.py)
+with all the features of this `ft_attention` kernel (and more).
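For readers landing here from the README change above, a minimal usage sketch of the replacement kernel is shown below. The argument names follow the `flash_attn_with_kvcache` call used later in this commit; the shapes, sizes, and the `cuda`/`float16` placement are illustrative assumptions, not values from the repo:

```python
import torch
from flash_attn import flash_attn_with_kvcache  # assumes flash-attn >= 2.2 is installed

batch, nheads, headdim, max_seqlen = 2, 16, 64, 256
# Pre-allocated KV cache in (batch, seqlen, nheads_kv, headdim) layout; no FT-specific packing.
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)
# Number of tokens already stored in the cache for each sequence.
cache_seqlens = torch.full((batch,), 10, dtype=torch.int32, device="cuda")

# One decoding step: query/key/value for the newest token.
q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# Appends k/v into the cache at position cache_seqlens (in place) and attends over the cache.
out = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens=cache_seqlens, causal=True)
print(out.shape)  # (batch, 1, nheads, headdim)
```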
@@ -32,11 +32,6 @@ try:
 except ImportError:
     RotaryEmbedding = None
-try:
-    import ft_attention
-except ImportError:
-    ft_attention = None
 class FlashSelfAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
@@ -314,14 +309,7 @@ def _update_kv_cache(kv, inference_params, layer_idx):
         )
         inference_params.key_value_memory_dict[layer_idx] = kv_cache
     else:
-        if not inference_params.fused_ft_kernel:
-            kv_cache = inference_params.key_value_memory_dict[layer_idx]
-        else:
-            # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
-            # where packsize = 4 if fp32, 8 if fp16 or bf16.
-            # v_cache has shape (b, h, s, headdim)
-            k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
-            kv_cache = None
+        kv_cache = inference_params.key_value_memory_dict[layer_idx]
     # Adjust key and value for inference
     batch_start = inference_params.batch_size_offset
     batch_end = batch_start + kv.shape[0]
@@ -329,79 +317,9 @@ def _update_kv_cache(kv, inference_params, layer_idx):
     sequence_end = sequence_start + kv.shape[1]
     assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
     assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
-    # Copy key and values.
-    if not inference_params.fused_ft_kernel:
-        assert kv_cache is not None
-        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-        kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
-        return kv
-    else:
-        assert inference_params.sequence_len_offset == 0
-        # FT kernel requires different layouts for the k_cache and v_cache.
-        assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
-        packsize = 4 if kv.dtype == torch.float32 else 8
-        if kv_cache is not None:
-            kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-            k_cache = rearrange(
-                kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
-            ).contiguous()
-            v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
-            inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
-        else:
-            k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
-                kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
-            )
-            v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
-                kv[:, :, 1], "b s h d -> b h s d"
-            )
-        return kv
-def _apply_rotary_single_query_attention(
-    qkv,
-    inference_params,
-    layer_idx,
-    rotary_emb_dim,
-    rotary_emb_base,
-    kv=None,
-    rotary_emb_interleaved=False,
-):
-    """
-    qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
-        q of shape (batch_size, 1, nheads, head_dim)
-    kv: (batch_size, 1, 2, nheads_kv, head_dim)
-    """
-    assert inference_params.fused_ft_kernel
-    assert ft_attention is not None
-    if kv is None:
-        q, k, v = rearrange(qkv, "b 1 three h d -> b three h d").unbind(dim=1)
-    else:
-        q = rearrange(qkv, "b 1 h d -> b h d")
-        k, v = rearrange(kv, "b 1 two h d -> b two h d").unbind(dim=1)
-    batch_start = inference_params.batch_size_offset
-    batch_end = batch_start + q.shape[0]
-    k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
-    lengths_per_sample = (
-        inference_params.lengths_per_sample[batch_start:batch_end]
-        if inference_params.lengths_per_sample is not None
-        else None
-    )
-    context = ft_attention.single_query_attention(
-        q,
-        k,
-        v,
-        k_cache[batch_start:batch_end],
-        v_cache[batch_start:batch_end],
-        lengths_per_sample,
-        None,  # rotary_cos_
-        None,  # rotary_sin_
-        None,  # nnz_head_idx
-        inference_params.sequence_len_offset,
-        rotary_emb_dim,
-        rotary_emb_base,
-        not rotary_emb_interleaved,  # neox_rotary_style
-    )
-    return rearrange(context, "b h d -> b 1 h d")
+    assert kv_cache is not None
+    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+    return kv_cache[batch_start:batch_end, :sequence_end, ...]
 class MHA(nn.Module):
@@ -502,36 +420,18 @@ class MHA(nn.Module):
         )
         self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
-    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
         dtype = self.out_proj.weight.dtype if dtype is None else dtype
         device = self.out_proj.weight.device
-        if not fused_ft_kernel:
-            return torch.empty(
-                batch_size,
-                max_seqlen,
-                2,
-                self.num_heads_kv,
-                self.head_dim,
-                dtype=dtype,
-                device=device,
-            )
-        else:
-            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
-            packsize = 4 if dtype == torch.float32 else 8
-            assert self.head_dim % packsize == 0
-            k_cache = torch.empty(
-                batch_size,
-                self.num_heads_kv,
-                self.head_dim // packsize,
-                max_seqlen,
-                packsize,
-                dtype=dtype,
-                device=device,
-            )
-            v_cache = torch.empty(
-                batch_size, self.num_heads_kv, max_seqlen, self.head_dim, dtype=dtype, device=device
-            )
-            return k_cache, v_cache
+        return torch.empty(
+            batch_size,
+            max_seqlen,
+            2,
+            self.num_heads_kv,
+            self.head_dim,
+            dtype=dtype,
+            device=device,
+        )
     def _update_kv_cache(self, kv, inference_params):
         """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
@@ -539,27 +439,46 @@ class MHA(nn.Module):
         assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
         return _update_kv_cache(kv, inference_params, self.layer_idx)
-    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
+    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
         """
-        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
-            q of shape (batch_size, 1, nheads, head_dim)
-        kv: (batch_size, 1, 2, nheads_kv, head_dim)
+        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
+        q: (batch_size, seqlen_q, nheads, head_dim)
+        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
         """
-        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
-        return _apply_rotary_single_query_attention(
-            qkv,
-            inference_params,
-            self.layer_idx,
-            self.rotary_emb_dim,
-            rotary_emb_base,
-            kv=kv,
-            rotary_emb_interleaved=self.rotary_emb.interleaved
-            if self.rotary_emb_dim > 0
-            else False,
-        )
+        assert inference_params is not None and inference_params.sequence_len_offset > 0
+        assert self.use_flash_attn
+        if self.rotary_emb_dim > 0:
+            assert self.rotary_emb.scale is None, "This code path does not support xPos"
+            self.rotary_emb._update_cos_sin_cache(
+                inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+            )
+            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
+        else:
+            rotary_cos, rotary_sin = None, None
+        batch = q.shape[0]
+        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
+        cache_seqlens = (
+            inference_params.lengths_per_sample[:batch]
+            if inference_params.lengths_per_sample is not None
+            else inference_params.sequence_len_offset
+        )
+        context = flash_attn_with_kvcache(
+            q,
+            kv_cache[:, :, 0],
+            kv_cache[:, :, 1],
+            kv[:, :, 0],
+            kv[:, :, 1],
+            rotary_cos=rotary_cos,
+            rotary_sin=rotary_sin,
+            cache_seqlens=cache_seqlens,
+            softmax_scale=self.inner_cross_attn.softmax_scale,
+            causal=self.inner_cross_attn.causal,
+            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+        )
+        return context
     def _update_kvcache_attention(self, q, kv, inference_params):
-        """Write kv to inference_params, then do attention """
+        """Write kv to inference_params, then do attention"""
         if (
             inference_params.sequence_len_offset == 0
             or flash_attn_with_kvcache is None
@@ -663,7 +582,8 @@ class MHA(nn.Module):
             if (
                 inference_params is None
                 or inference_params.sequence_len_offset == 0
-                or not inference_params.fused_ft_kernel
+                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
+                or not self.use_flash_attn
             ):
                 if self.rotary_emb_dim > 0:
                     qkv = self.rotary_emb(
@@ -679,7 +599,9 @@ class MHA(nn.Module):
                         qkv[:, :, 0], qkv[:, :, 1:], inference_params
                     )
                 else:
-                    context = self._apply_rotary_single_query_attention(qkv, inference_params)
+                    context = self._apply_rotary_update_kvcache_attention(
+                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
+                    )
         else:
             if self.cross_attn:
                 if not self.return_residual:
@@ -711,7 +633,8 @@ class MHA(nn.Module):
             if (
                 inference_params is None
                 or inference_params.sequence_len_offset == 0
-                or not inference_params.fused_ft_kernel
+                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
+                or not self.use_flash_attn
             ):
                 if self.rotary_emb_dim > 0:
                     q, kv = self.rotary_emb(
@@ -727,7 +650,7 @@ class MHA(nn.Module):
                 else:
                     context = self._update_kvcache_attention(q, kv, inference_params)
             else:
-                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
+                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
         out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
         return out if not self.return_residual else (out, x)
@@ -825,73 +748,65 @@ class ParallelMHA(nn.Module):
             **factory_kwargs,
         )
-    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
         dtype = self.out_proj.weight.dtype if dtype is None else dtype
         device = self.out_proj.weight.device
-        if not fused_ft_kernel:
-            return torch.empty(
-                batch_size,
-                max_seqlen,
-                2,
-                self.num_heads_kv_per_rank,
-                self.head_dim,
-                dtype=dtype,
-                device=device,
-            )
-        else:
-            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
-            packsize = 4 if dtype == torch.float32 else 8
-            assert self.head_dim % packsize == 0
-            k_cache = torch.empty(
-                batch_size,
-                self.num_heads_kv_per_rank,
-                self.head_dim // packsize,
-                max_seqlen,
-                packsize,
-                dtype=dtype,
-                device=device,
-            )
-            v_cache = torch.empty(
-                batch_size,
-                self.num_heads_kv_per_rank,
-                max_seqlen,
-                self.head_dim,
-                dtype=dtype,
-                device=device,
-            )
-            return k_cache, v_cache
+        return torch.empty(
+            batch_size,
+            max_seqlen,
+            2,
+            self.num_heads_kv_per_rank,
+            self.head_dim,
+            dtype=dtype,
+            device=device,
+        )
     def _update_kv_cache(self, kv, inference_params):
         """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
         assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
         return _update_kv_cache(kv, inference_params, self.layer_idx)
-    def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
+    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
         """
-        qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
-            q of shape (batch_size, 1, nheads, head_dim)
-        kv: (batch_size, 1, 2, nheads_kv, head_dim)
+        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
+        q: (batch_size, seqlen_q, nheads, head_dim)
+        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
         """
-        rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
-        return _apply_rotary_single_query_attention(
-            qkv,
-            inference_params,
-            self.layer_idx,
-            self.rotary_emb_dim,
-            rotary_emb_base,
-            kv=kv,
-            rotary_emb_interleaved=self.rotary_emb.interleaved
-            if self.rotary_emb_dim > 0
-            else False,
-        )
+        assert inference_params is not None and inference_params.sequence_len_offset > 0
+        assert self.use_flash_attn
+        if self.rotary_emb_dim > 0:
+            assert self.rotary_emb.scale is None, "This code path does not support xPos"
+            self.rotary_emb._update_cos_sin_cache(
+                inference_params.max_sequence_len, device=q.device, dtype=q.dtype
+            )
+            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
+        else:
+            rotary_cos, rotary_sin = None, None
+        batch = q.shape[0]
+        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
+        cache_seqlens = (
+            inference_params.lengths_per_sample[:batch]
+            if inference_params.lengths_per_sample is not None
+            else inference_params.sequence_len_offset
+        )
+        context = flash_attn_with_kvcache(
+            q,
+            kv_cache[:, :, 0],
+            kv_cache[:, :, 1],
+            kv[:, :, 0],
+            kv[:, :, 1],
+            rotary_cos=rotary_cos,
+            rotary_sin=rotary_sin,
+            cache_seqlens=cache_seqlens,
+            softmax_scale=self.inner_cross_attn.softmax_scale,
+            causal=self.inner_cross_attn.causal,
+            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+        )
+        return context
     def _update_kvcache_attention(self, q, kv, inference_params):
-        """Write kv to inference_params, then do attention """
-        if (
-            inference_params.sequence_len_offset == 0
-            or flash_attn_with_kvcache is None
-            or not self.use_flash_attn
-        ):
+        """Write kv to inference_params, then do attention"""
+        if inference_params.sequence_len_offset == 0 or not self.use_flash_attn:
             # TODO: this only uses sequence_len_offset and not lengths_per_sample.
             kv = self._update_kv_cache(kv, inference_params)
             return self.inner_cross_attn(q, kv)
@@ -943,7 +858,8 @@ class ParallelMHA(nn.Module):
             if (
                 inference_params is None
                 or inference_params.sequence_len_offset == 0
-                or not inference_params.fused_ft_kernel
+                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
+                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
@@ -959,7 +875,9 @@ class ParallelMHA(nn.Module):
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
                else:
-                    context = self._apply_rotary_single_query_attention(qkv, inference_params)
+                    context = self._apply_rotary_update_kvcache_attention(
+                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
+                    )
        else:
            q = rearrange(
                qkv[..., : self.num_heads_per_rank * self.head_dim],
@@ -975,7 +893,8 @@ class ParallelMHA(nn.Module):
            if (
                inference_params is None
                or inference_params.sequence_len_offset == 0
-                or not inference_params.fused_ft_kernel
+                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
+                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
@@ -991,7 +910,7 @@ class ParallelMHA(nn.Module):
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
-                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
+                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        context = rearrange(context, "b s h d -> b s (h d)")
        if seqlen is not None:
            context = rearrange(context, "b s d -> (b s) d")
...
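For context on what the rewritten `allocate_inference_cache` / `_update_kv_cache` above now store per layer, here is a small illustrative sketch contrasting the old FT cache layout with the single tensor consumed by `flash_attn_with_kvcache` (the batch/seqlen/head sizes are made-up example values, not taken from the repo):

```python
import torch

batch, max_seqlen, nheads_kv, headdim = 2, 1024, 16, 64
dtype = torch.float16

# Old ft_attention layout: two tensors per layer, with K packed along the head dimension
# (packsize = 4 for fp32, 8 for fp16/bf16).
packsize = 8
k_cache_ft = torch.empty(batch, nheads_kv, headdim // packsize, max_seqlen, packsize, dtype=dtype)
v_cache_ft = torch.empty(batch, nheads_kv, max_seqlen, headdim, dtype=dtype)

# New layout used with flash_attn_with_kvcache: one contiguous tensor per layer,
# where kv_cache[:, :, 0] is K and kv_cache[:, :, 1] is V.
kv_cache = torch.empty(batch, max_seqlen, 2, nheads_kv, headdim, dtype=dtype)
print(kv_cache[:, :, 0].shape)  # (batch, max_seqlen, nheads_kv, headdim)
```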
@@ -25,7 +25,6 @@ class InferenceParams:
     sequence_len_offset: int = 0
     batch_size_offset: int = 0
     key_value_memory_dict: dict = field(default_factory=dict)
-    fused_ft_kernel: bool = False
     lengths_per_sample: Optional[Tensor] = None
@@ -96,7 +95,6 @@ def decode(
     teacher_outputs=None,
     vocab_size=None,
     tensor_parallel=1,
-    fused_ft_kernel=False,
     cg=False,
     enable_timing=False,
 ):
@@ -127,7 +125,6 @@ def decode(
             seqlen_og,
             max_length,
             tensor_parallel=tensor_parallel,
-            fused_ft_kernel=fused_ft_kernel,
         )
         inference_params = model._decoding_cache.inference_params
         inference_params.max_sequence_len = max_length
@@ -135,9 +132,7 @@ def decode(
         inference_params.sequence_len_offset = 0
         inference_params.lengths_per_sample.zero_()
     else:
-        inference_params = InferenceParams(
-            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
-        )
+        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
     def get_logits(input_ids, inference_params):
         decoding = inference_params.sequence_len_offset > 0
@@ -273,7 +268,6 @@ def decode_speculative(
     eos_token_id=None,
     vocab_size=None,
     tensor_parallel=1,
-    fused_ft_kernel=False,
     cg=False,
     enable_timing=False,
     debug=False,
@@ -307,23 +301,17 @@ def decode_speculative(
             seqlen_og,
             max_length,
             tensor_parallel=tensor_parallel,
-            fused_ft_kernel=fused_ft_kernel,
         )
         inference_params_draft = model_draft._decoding_cache.inference_params
         inference_params_draft.max_sequence_len = max_length
         inference_params_draft.max_batch_size = batch_size
         inference_params_draft.sequence_len_offset = 0
-        # fused_ft_kernel doesn't support passing in multiple tokens at once
-        inference_params = InferenceParams(
-            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=False
-        )
+        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
     else:
         inference_params_draft = InferenceParams(
-            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
+            max_sequence_len=max_length, max_batch_size=batch_size
         )
-        inference_params = InferenceParams(
-            max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=False
-        )
+        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
     def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False):
         if not cg:
@@ -606,7 +594,6 @@ def allocate_inference_cache(
     layers: Union[int, Sequence],
     device,
     dtype=torch.float16,
-    fused_ft_kernel=False,
 ):
     assert dtype in [torch.float16, torch.bfloat16, torch.float32]
     packsize = 4 if dtype == torch.float32 else 8
@@ -616,15 +603,7 @@ def allocate_inference_cache(
     kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
     if isinstance(layers, int):
         layers = range(layers)
-    return {
-        i: (
-            torch.empty(k_cache_shape, device=device, dtype=dtype),
-            torch.empty(v_cache_shape, device=device, dtype=dtype),
-        )
-        if fused_ft_kernel
-        else torch.empty(kv_cache_sahpe, device=device, dtype=dtype)
-        for i in layers
-    }
+    return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}
 def seqlen_to_seqlen_type(seqlen: int) -> int:
@@ -633,12 +612,12 @@ def seqlen_to_seqlen_type(seqlen: int) -> int:
     Arguments:
         seqlen: int
     """
-    return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)
+    return 0
 def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
-    assert seqlen_type in [0, 1, 2]
-    return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2**32)
+    assert seqlen_type in [0]
+    return 2**32
 @dataclass
@@ -663,7 +642,6 @@ def update_graph_cache(
     tensor_parallel=1,
     dtype=None,
     n_warmups=2,
-    fused_ft_kernel=False,
 ):
     if cache is None:
         cache = DecodingCGCache()
@@ -683,9 +661,7 @@ def update_graph_cache(
         cache.device, cache.dtype = device, dtype
         cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
         if hasattr(model, "allocate_inference_cache"):
-            inf_cache = model.allocate_inference_cache(
-                batch_size, max_seqlen, dtype, fused_ft_kernel=fused_ft_kernel
-            )
+            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
         else:
             headdim = getattr(
                 model.config,
@@ -700,7 +676,6 @@ def update_graph_cache(
                 model.config.num_hidden_layers,
                 device,
                 dtype,
-                fused_ft_kernel=fused_ft_kernel,
             )
         lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
         cache.inference_params = InferenceParams(
@@ -708,7 +683,6 @@ def update_graph_cache(
             max_batch_size=batch_size,
             sequence_len_offset=seqlen_og,
             key_value_memory_dict=inf_cache,
-            fused_ft_kernel=fused_ft_kernel,
             lengths_per_sample=lengths_per_sample,
         )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
...
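A short sketch of how the per-layer caches and `InferenceParams` are set up after this change. The layer count and sizes are made-up example values, and the positional argument order of `allocate_inference_cache` is assumed from the function shown above:

```python
import torch
from flash_attn.utils.generation import InferenceParams, allocate_inference_cache

batch_size, max_seqlen, nheads, headdim, num_layers = 2, 256, 16, 64, 4
kv_caches = allocate_inference_cache(
    batch_size, max_seqlen, nheads, headdim, num_layers, "cuda", dtype=torch.float16
)
print(kv_caches[0].shape)  # (batch_size, max_seqlen, 2, nheads, headdim), one tensor per layer

inference_params = InferenceParams(
    max_sequence_len=max_seqlen, max_batch_size=batch_size, key_value_memory_dict=kv_caches
)
# There is no fused_ft_kernel field anymore; CUDA-graph decoding (cg=True) now relies
# solely on the flash_attn_with_kvcache fast path.
```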
@@ -122,10 +122,10 @@ if not SKIP_CUDA_BUILD:
     # cc_flag.append("arch=compute_75,code=sm_75")
     cc_flag.append("-gencode")
     cc_flag.append("arch=compute_80,code=sm_80")
-    if CUDA_HOME is not None:
-        if bare_metal_version >= Version("11.8"):
-            cc_flag.append("-gencode")
-            cc_flag.append("arch=compute_90,code=sm_90")
+    # if CUDA_HOME is not None:
+    # if bare_metal_version >= Version("11.8"):
+    # cc_flag.append("-gencode")
+    # cc_flag.append("arch=compute_90,code=sm_90")
     # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
     # torch._C._GLIBCXX_USE_CXX11_ABI
...
@@ -217,9 +217,8 @@ def test_baichuan_parallel_forward(model_name, world_size):
     ).abs().max().item()
-@pytest.mark.parametrize("fused_ft_kernel", [False, True])
 @pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
-def test_baichuan_generation(model_name, fused_ft_kernel):
+def test_baichuan_generation(model_name):
     dtype = torch.float16
     device = "cuda"
     config = baichuan_config_to_gpt2_config(
@@ -236,8 +235,8 @@ def test_baichuan_generation(model_name, fused_ft_kernel):
     torch.manual_seed(0)
     batch_size = 1
-    seqlen = 100
-    max_length = 150
+    seqlen = 2048
+    max_length = 2048 + 150
     input_ids = torch.randint(
         0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
     )
@@ -285,7 +284,6 @@ def test_baichuan_generation(model_name, fused_ft_kernel):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -296,16 +294,13 @@ def test_baichuan_generation(model_name, fused_ft_kernel):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=fused_ft_kernel,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
@@ -403,9 +398,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=False
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,
...
@@ -141,7 +141,6 @@ def test_bigcode_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -159,7 +158,6 @@ def test_bigcode_generation(model_name):
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
...
@@ -242,7 +242,6 @@ def test_falcon_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -253,16 +252,13 @@ def test_falcon_generation(model_name):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
@@ -349,7 +345,6 @@ def test_falcon_parallel_generation(model_name, world_size):
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=True,
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
@@ -358,16 +353,13 @@ def test_falcon_parallel_generation(model_name, world_size):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=True,
         cg=True,
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
...
@@ -134,14 +134,12 @@ def test_gpt2_optimized(model_name):
     ).abs().max().item()
-@pytest.mark.parametrize("fused_ft_kernel", [False, True])
-# @pytest.mark.parametrize('fused_ft_kernel', [True])
 @pytest.mark.parametrize("optimized", [False, True])
 # @pytest.mark.parametrize('optimized', [True])
 @pytest.mark.parametrize("rotary", [False, True])
 # @pytest.mark.parametrize('rotary', [False])
 @pytest.mark.parametrize("model_name", ["gpt2"])
-def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
+def test_gpt2_generation(model_name, rotary, optimized):
     """Check that our implementation of GPT2 generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
@@ -202,18 +200,16 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
     out = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
     )
     print(out.sequences)
     print(tokenizer.batch_decode(out.sequences.tolist()))
-    if fused_ft_kernel or getattr(config, "use_flash_attn", False):
+    if getattr(config, "use_flash_attn", False):
         out_cg = model.generate(
             input_ids=input_ids,
             max_length=max_length,
-            fused_ft_kernel=fused_ft_kernel,
             cg=True,
             return_dict_in_generate=True,
             output_scores=True,
@@ -282,10 +278,8 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
 # @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
 @pytest.mark.parametrize("rotary", [None, "interleaved", "contiguous"])
 # @pytest.mark.parametrize('rotary', [None])
-@pytest.mark.parametrize("fused_ft_kernel", [False, True])
-# @pytest.mark.parametrize("fused_ft_kernel", [False])
 @pytest.mark.parametrize("model_name", ["gpt2"])
-def test_gpt2_generation_cg(model_name, fused_ft_kernel, rotary, seqlen, maxlen):
+def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
     """Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
     dtype = torch.float16
     device = "cuda"
@@ -315,17 +309,8 @@ def test_gpt2_generation_cg(model_name, fused_ft_kernel, rotary, seqlen, maxlen)
         0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
     )
-    logits = get_logits(
-        model, input_ids, maxlen, teacher_outputs=teacher_outputs, fused_ft_kernel=fused_ft_kernel
-    )
-    logits_cg = get_logits(
-        model,
-        input_ids,
-        maxlen,
-        teacher_outputs=teacher_outputs,
-        fused_ft_kernel=fused_ft_kernel,
-        cg=True,
-    )
+    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
+    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
     assert torch.equal(logits, logits_cg)
     # Try increasing batch size and seqlen, then decrease them to see if it's still correct
@@ -369,7 +354,6 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
     config.fused_bias_fc = True
     config.fused_mlp = True
     config.fused_dropout_add_ln = True
-    # fused_ft_kernel currently doesn't work with multiple tokens at a time
     model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
     model.eval()
@@ -398,13 +382,12 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
     assert torch.allclose(logits, logits_ref, rtol=rtol, atol=atol)
-@pytest.mark.parametrize("fused_ft_kernel, cg", [(False, False), (True, False), (True, True)])
-# @pytest.mark.parametrize("fused_ft_kernel, cg", [(True, True)])
+@pytest.mark.parametrize("cg", [False, True])
 # @pytest.mark.parametrize("optimized", [False, True])
 @pytest.mark.parametrize("optimized", [True])
 # @pytest.mark.parametrize("model_name", ["gpt2-medium"])
 @pytest.mark.parametrize("model_name", ["gpt2-xl"])
-def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
+def test_gpt2_speculative_decoding(model_name, optimized, cg):
     dtype = torch.float16
     device = "cuda"
     rtol, atol = 3e-3, 3e-1
@@ -444,7 +427,6 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
         model_draft,
         max_length=max_length,
         top_k=5,
-        fused_ft_kernel=fused_ft_kernel,
         cg=cg,
         speculative_lookahead=4,
         enable_timing=True,
@@ -454,7 +436,6 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
         input_ids,
         max_length=max_length,
         top_k=5,
-        fused_ft_kernel=fused_ft_kernel,
         cg=False,
         enable_timing=True,
         return_dict_in_generate=True,
...
@@ -15,12 +15,10 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHead
 # @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
 @pytest.mark.parametrize("world_size", [2])
-# @pytest.mark.parametrize('fused_ft_kernel', [False, True])
-@pytest.mark.parametrize("fused_ft_kernel", [True])
-# @pytest.mark.parametrize('rotary', [False, True])
-@pytest.mark.parametrize("rotary", [False])
+@pytest.mark.parametrize('rotary', [False, True])
+# @pytest.mark.parametrize("rotary", [False])
 @pytest.mark.parametrize("model_name", ["gpt2"])
-def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
+def test_tensor_parallel(model_name, rotary, world_size):
     """Check that our implementation of GPT2 generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
@@ -111,19 +109,17 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
     )
     print(out.sequences)
-    if fused_ft_kernel:
+    if getattr(config, "use_flash_attn", False):
         out_cg = model.generate(
             input_ids=input_ids,
             max_length=max_length,
             tensor_parallel=world_size,
             vocab_size=config.vocab_size,
-            fused_ft_kernel=fused_ft_kernel,
             cg=True,
             return_dict_in_generate=True,
             output_scores=True,
...
@@ -83,9 +83,8 @@ def test_gptj_optimized(model_name):
     ).abs().max().item()
-@pytest.mark.parametrize("fused_ft_kernel", [False, True])
 @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
-def test_gptj_generation(model_name, fused_ft_kernel):
+def test_gptj_generation(model_name):
     """Check that our implementation of GPT-J (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
@@ -141,7 +140,6 @@ def test_gptj_generation(model_name, fused_ft_kernel):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -152,16 +150,13 @@ def test_gptj_generation(model_name, fused_ft_kernel):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=fused_ft_kernel,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
...
@@ -292,7 +292,6 @@ def test_llama_generation(model_name, checkpoint_format):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -303,16 +302,13 @@ def test_llama_generation(model_name, checkpoint_format):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
@@ -401,7 +397,6 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=True,
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
@@ -410,16 +405,13 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
-    model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
-    )
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
     print("With CUDA graph")
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=True,
         cg=True,
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
...
@@ -107,7 +107,6 @@ def test_opt_generation(model_name):
     dtype = torch.float16
     device = "cuda"
     rtol, atol = 3e-3, 3e-1
-    fused_ft_kernel = True
     config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
     # Only prenorm supports residual_in_fp32
     config.residual_in_fp32 = getattr(config, "prenorm", True)
@@ -155,7 +154,6 @@ def test_opt_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -165,19 +163,16 @@ def test_opt_generation(model_name):
     if verbose:
         print(out.sequences)
         print(tokenizer.batch_decode(out.sequences.tolist()))
-    if fused_ft_kernel:
+    if getattr(config, "use_flash_attn", False):
         # Capture graph outside the timing loop
         batch_size, seqlen_og = input_ids.shape
-        model._decoding_cache = update_graph_cache(
-            model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
-        )
+        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
         print("With CUDA graph")
         torch.cuda.synchronize()
         start = time.time()
         out_cg = model.generate(
             input_ids=input_ids,
             max_length=max_length,
-            fused_ft_kernel=fused_ft_kernel,
             cg=True,
             return_dict_in_generate=True,
             output_scores=True,
...