Commit dfe29f5e authored by Tri Dao

[Gen] Don't use ft_attention, use flash_attn_with_kvcache instead

parent 3250ff3d
......@@ -6,3 +6,9 @@ FasterTransformer v5.2.1 for benchmarking purpose.
```sh
cd csrc/ft_attention && pip install .
```
As of 2023-09-17, this extension is no longer used in the FlashAttention repo.
FlashAttention now implements
[`flash_attn_with_kvcache`](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py),
which covers all the features of this `ft_attention` kernel (and more).
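
For quick reference, a minimal decode-step sketch of the replacement API (tensor shapes and values are illustrative; the cache is updated in place, and rotary embeddings, if passed via `rotary_cos`/`rotary_sin`, are applied inside the kernel):

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_seqlen = 2, 16, 64, 512
dtype, device = torch.float16, "cuda"

# One new query token per sequence, plus its key/value to append to the cache.
q = torch.randn(batch, 1, nheads, headdim, dtype=dtype, device=device)
k = torch.randn(batch, 1, nheads, headdim, dtype=dtype, device=device)
v = torch.randn(batch, 1, nheads, headdim, dtype=dtype, device=device)

# Pre-allocated KV cache: (batch, max_seqlen, nheads_k, headdim).
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, dtype=dtype, device=device)
v_cache = torch.zeros_like(k_cache)
# Number of tokens already stored in the cache for each sequence.
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device=device)

# k/v are written into k_cache/v_cache at position cache_seqlens (in place),
# then attention is computed over the updated cache.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens, causal=True
)  # (batch, 1, nheads, headdim)
```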
......@@ -32,11 +32,6 @@ try:
except ImportError:
RotaryEmbedding = None
try:
import ft_attention
except ImportError:
ft_attention = None
class FlashSelfAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
......@@ -314,14 +309,7 @@ def _update_kv_cache(kv, inference_params, layer_idx):
)
inference_params.key_value_memory_dict[layer_idx] = kv_cache
else:
if not inference_params.fused_ft_kernel:
kv_cache = inference_params.key_value_memory_dict[layer_idx]
else:
# For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
# where packsize = 4 if fp32, 8 if fp16 or bf16.
# v_cache has shape (b, h, s, headdim)
k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
kv_cache = None
kv_cache = inference_params.key_value_memory_dict[layer_idx]
# Adjust key and value for inference
batch_start = inference_params.batch_size_offset
batch_end = batch_start + kv.shape[0]
......@@ -329,79 +317,9 @@ def _update_kv_cache(kv, inference_params, layer_idx):
sequence_end = sequence_start + kv.shape[1]
assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
# Copy key and values.
if not inference_params.fused_ft_kernel:
assert kv_cache is not None
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
return kv
else:
assert inference_params.sequence_len_offset == 0
# FT kernel requires different layouts for the k_cache and v_cache.
assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if kv.dtype == torch.float32 else 8
if kv_cache is not None:
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
k_cache = rearrange(
kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
).contiguous()
v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
else:
k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
)
v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
kv[:, :, 1], "b s h d -> b h s d"
)
return kv
def _apply_rotary_single_query_attention(
qkv,
inference_params,
layer_idx,
rotary_emb_dim,
rotary_emb_base,
kv=None,
rotary_emb_interleaved=False,
):
"""
qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
q of shape (batch_size, 1, nheads, head_dim)
kv: (batch_size, 1, 2, nheads_kv, head_dim)
"""
assert inference_params.fused_ft_kernel
assert ft_attention is not None
if kv is None:
q, k, v = rearrange(qkv, "b 1 three h d -> b three h d").unbind(dim=1)
else:
q = rearrange(qkv, "b 1 h d -> b h d")
k, v = rearrange(kv, "b 1 two h d -> b two h d").unbind(dim=1)
batch_start = inference_params.batch_size_offset
batch_end = batch_start + q.shape[0]
k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
lengths_per_sample = (
inference_params.lengths_per_sample[batch_start:batch_end]
if inference_params.lengths_per_sample is not None
else None
)
context = ft_attention.single_query_attention(
q,
k,
v,
k_cache[batch_start:batch_end],
v_cache[batch_start:batch_end],
lengths_per_sample,
None, # rotary_cos_
None, # rotary_sin_
None, # nnz_head_idx
inference_params.sequence_len_offset,
rotary_emb_dim,
rotary_emb_base,
not rotary_emb_interleaved, # neox_rotary_style
)
return rearrange(context, "b h d -> b 1 h d")
assert kv_cache is not None
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
return kv_cache[batch_start:batch_end, :sequence_end, ...]
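
The removed branches above only existed to maintain the FT kernel's packed `(b, h, headdim / packsize, s, packsize)` layout; what remains is a plain slice-assign into a single `(batch, seqlen, 2, nheads_kv, headdim)` cache. A minimal standalone sketch of that update, with hypothetical tensor names and shapes taken from the docstrings in this file:

```python
import torch

batch, max_seqlen, nheads_kv, headdim = 2, 512, 16, 64
# Unified cache layout: (batch, seqlen, 2, nheads_kv, headdim); index 0 is K, index 1 is V.
kv_cache = torch.zeros(batch, max_seqlen, 2, nheads_kv, headdim, dtype=torch.float16)

def update_kv_cache(kv_cache, kv, seqlen_offset):
    """kv: (batch, seqlen_new, 2, nheads_kv, headdim) for the newly processed tokens."""
    seqlen_new = kv.shape[1]
    kv_cache[:, seqlen_offset : seqlen_offset + seqlen_new] = kv
    # Return everything cached so far, ready to hand to attention.
    return kv_cache[:, : seqlen_offset + seqlen_new]

kv_new = torch.randn(batch, 1, 2, nheads_kv, headdim, dtype=torch.float16)
kv_so_far = update_kv_cache(kv_cache, kv_new, seqlen_offset=100)  # (batch, 101, 2, nheads_kv, headdim)
```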
class MHA(nn.Module):
......@@ -502,36 +420,18 @@ class MHA(nn.Module):
)
self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
dtype = self.out_proj.weight.dtype if dtype is None else dtype
device = self.out_proj.weight.device
if not fused_ft_kernel:
return torch.empty(
batch_size,
max_seqlen,
2,
self.num_heads_kv,
self.head_dim,
dtype=dtype,
device=device,
)
else:
assert dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if dtype == torch.float32 else 8
assert self.head_dim % packsize == 0
k_cache = torch.empty(
batch_size,
self.num_heads_kv,
self.head_dim // packsize,
max_seqlen,
packsize,
dtype=dtype,
device=device,
)
v_cache = torch.empty(
batch_size, self.num_heads_kv, max_seqlen, self.head_dim, dtype=dtype, device=device
)
return k_cache, v_cache
return torch.empty(
batch_size,
max_seqlen,
2,
self.num_heads_kv,
self.head_dim,
dtype=dtype,
device=device,
)
def _update_kv_cache(self, kv, inference_params):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
......@@ -539,27 +439,46 @@ class MHA(nn.Module):
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
return _update_kv_cache(kv, inference_params, self.layer_idx)
def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
"""
qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
q of shape (batch_size, 1, nheads, head_dim)
kv: (batch_size, 1, 2, nheads_kv, head_dim)
Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
q: (batch_size, seqlen_q, nheads, head_dim)
kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
"""
rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
return _apply_rotary_single_query_attention(
qkv,
inference_params,
self.layer_idx,
self.rotary_emb_dim,
rotary_emb_base,
kv=kv,
rotary_emb_interleaved=self.rotary_emb.interleaved
if self.rotary_emb_dim > 0
else False,
assert inference_params is not None and inference_params.sequence_len_offset > 0
assert self.use_flash_attn
if self.rotary_emb_dim > 0:
assert self.rotary_emb.scale is None, "This code path does not support xPos"
self.rotary_emb._update_cos_sin_cache(
inference_params.max_sequence_len, device=q.device, dtype=q.dtype
)
rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
else:
rotary_cos, rotary_sin = None, None
batch = q.shape[0]
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
else inference_params.sequence_len_offset
)
context = flash_attn_with_kvcache(
q,
kv_cache[:, :, 0],
kv_cache[:, :, 1],
kv[:, :, 0],
kv[:, :, 1],
rotary_cos=rotary_cos,
rotary_sin=rotary_sin,
cache_seqlens=cache_seqlens,
softmax_scale=self.inner_cross_attn.softmax_scale,
causal=self.inner_cross_attn.causal,
rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
)
return context
def _update_kvcache_attention(self, q, kv, inference_params):
"""Write kv to inference_params, then do attention """
"""Write kv to inference_params, then do attention"""
if (
inference_params.sequence_len_offset == 0
or flash_attn_with_kvcache is None
......@@ -663,7 +582,8 @@ class MHA(nn.Module):
if (
inference_params is None
or inference_params.sequence_len_offset == 0
or not inference_params.fused_ft_kernel
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
or not self.use_flash_attn
):
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(
......@@ -679,7 +599,9 @@ class MHA(nn.Module):
qkv[:, :, 0], qkv[:, :, 1:], inference_params
)
else:
context = self._apply_rotary_single_query_attention(qkv, inference_params)
context = self._apply_rotary_update_kvcache_attention(
qkv[:, :, 0], qkv[:, :, 1:], inference_params
)
else:
if self.cross_attn:
if not self.return_residual:
......@@ -711,7 +633,8 @@ class MHA(nn.Module):
if (
inference_params is None
or inference_params.sequence_len_offset == 0
or not inference_params.fused_ft_kernel
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
or not self.use_flash_attn
):
if self.rotary_emb_dim > 0:
q, kv = self.rotary_emb(
......@@ -727,7 +650,7 @@ class MHA(nn.Module):
else:
context = self._update_kvcache_attention(q, kv, inference_params)
else:
context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
return out if not self.return_residual else (out, x)
......@@ -825,73 +748,65 @@ class ParallelMHA(nn.Module):
**factory_kwargs,
)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
dtype = self.out_proj.weight.dtype if dtype is None else dtype
device = self.out_proj.weight.device
if not fused_ft_kernel:
return torch.empty(
batch_size,
max_seqlen,
2,
self.num_heads_kv_per_rank,
self.head_dim,
dtype=dtype,
device=device,
)
else:
assert dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if dtype == torch.float32 else 8
assert self.head_dim % packsize == 0
k_cache = torch.empty(
batch_size,
self.num_heads_kv_per_rank,
self.head_dim // packsize,
max_seqlen,
packsize,
dtype=dtype,
device=device,
)
v_cache = torch.empty(
batch_size,
self.num_heads_kv_per_rank,
max_seqlen,
self.head_dim,
dtype=dtype,
device=device,
)
return k_cache, v_cache
return torch.empty(
batch_size,
max_seqlen,
2,
self.num_heads_kv_per_rank,
self.head_dim,
dtype=dtype,
device=device,
)
def _update_kv_cache(self, kv, inference_params):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
return _update_kv_cache(kv, inference_params, self.layer_idx)
def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
"""
qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
q of shape (batch_size, 1, nheads, head_dim)
kv: (batch_size, 1, 2, nheads_kv, head_dim)
Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
q: (batch_size, seqlen_q, nheads, head_dim)
kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
"""
rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
return _apply_rotary_single_query_attention(
qkv,
inference_params,
self.layer_idx,
self.rotary_emb_dim,
rotary_emb_base,
kv=kv,
rotary_emb_interleaved=self.rotary_emb.interleaved
if self.rotary_emb_dim > 0
else False,
assert inference_params is not None and inference_params.sequence_len_offset > 0
assert self.use_flash_attn
if self.rotary_emb_dim > 0:
assert self.rotary_emb.scale is None, "This code path does not support xPos"
self.rotary_emb._update_cos_sin_cache(
inference_params.max_sequence_len, device=q.device, dtype=q.dtype
)
rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
else:
rotary_cos, rotary_sin = None, None
batch = q.shape[0]
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
else inference_params.sequence_len_offset
)
context = flash_attn_with_kvcache(
q,
kv_cache[:, :, 0],
kv_cache[:, :, 1],
kv[:, :, 0],
kv[:, :, 1],
rotary_cos=rotary_cos,
rotary_sin=rotary_sin,
cache_seqlens=cache_seqlens,
softmax_scale=self.inner_cross_attn.softmax_scale,
causal=self.inner_cross_attn.causal,
rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
)
return context
def _update_kvcache_attention(self, q, kv, inference_params):
"""Write kv to inference_params, then do attention """
if (
inference_params.sequence_len_offset == 0
or flash_attn_with_kvcache is None
or not self.use_flash_attn
):
"""Write kv to inference_params, then do attention"""
if inference_params.sequence_len_offset == 0 or not self.use_flash_attn:
# TODO: this only uses sequence_len_offset and not lengths_per_sample.
kv = self._update_kv_cache(kv, inference_params)
return self.inner_cross_attn(q, kv)
......@@ -943,7 +858,8 @@ class ParallelMHA(nn.Module):
if (
inference_params is None
or inference_params.sequence_len_offset == 0
or not inference_params.fused_ft_kernel
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
or not self.use_flash_attn
):
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(
......@@ -959,7 +875,9 @@ class ParallelMHA(nn.Module):
qkv[:, :, 0], qkv[:, :, 1:], inference_params
)
else:
context = self._apply_rotary_single_query_attention(qkv, inference_params)
context = self._apply_rotary_update_kvcache_attention(
qkv[:, :, 0], qkv[:, :, 1:], inference_params
)
else:
q = rearrange(
qkv[..., : self.num_heads_per_rank * self.head_dim],
......@@ -975,7 +893,8 @@ class ParallelMHA(nn.Module):
if (
inference_params is None
or inference_params.sequence_len_offset == 0
or not inference_params.fused_ft_kernel
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
or not self.use_flash_attn
):
if self.rotary_emb_dim > 0:
q, kv = self.rotary_emb(
......@@ -991,7 +910,7 @@ class ParallelMHA(nn.Module):
else:
context = self._update_kvcache_attention(q, kv, inference_params)
else:
context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
context = rearrange(context, "b s h d -> b s (h d)")
if seqlen is not None:
context = rearrange(context, "b s d -> (b s) d")
......
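
Putting the `mha.py` changes together, a decoding loop now needs only the single per-layer cache tensor plus the `InferenceParams` fields kept below. A hedged end-to-end sketch of prefill followed by one fused decode step; the `MHA` constructor keywords and module paths are assumptions about the installed flash-attn version, and it needs a CUDA GPU supported by FlashAttention:

```python
import torch
from flash_attn.modules.mha import MHA
from flash_attn.utils.generation import InferenceParams

device, dtype = "cuda", torch.float16
batch, max_seqlen, embed_dim, nheads = 2, 512, 1024, 16

mha = MHA(
    embed_dim, nheads, causal=True, layer_idx=0, rotary_emb_dim=64,
    use_flash_attn=True, device=device, dtype=dtype,
)
inference_params = InferenceParams(max_sequence_len=max_seqlen, max_batch_size=batch)
inference_params.key_value_memory_dict[0] = mha.allocate_inference_cache(batch, max_seqlen, dtype=dtype)

# Prefill: sequence_len_offset == 0, so rotary is applied in Python and the cache
# is filled through _update_kvcache_attention.
prompt = torch.randn(batch, 100, embed_dim, device=device, dtype=dtype)
out = mha(prompt, inference_params=inference_params)
inference_params.sequence_len_offset += prompt.shape[1]

# Decode step: sequence_len_offset > 0, rotary_emb_dim % 16 == 0, use_flash_attn=True,
# so the fused flash_attn_with_kvcache path (rotary + cache update + attention) is taken.
new_token = torch.randn(batch, 1, embed_dim, device=device, dtype=dtype)
out = mha(new_token, inference_params=inference_params)
inference_params.sequence_len_offset += 1
```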
......@@ -25,7 +25,6 @@ class InferenceParams:
sequence_len_offset: int = 0
batch_size_offset: int = 0
key_value_memory_dict: dict = field(default_factory=dict)
fused_ft_kernel: bool = False
lengths_per_sample: Optional[Tensor] = None
......@@ -96,7 +95,6 @@ def decode(
teacher_outputs=None,
vocab_size=None,
tensor_parallel=1,
fused_ft_kernel=False,
cg=False,
enable_timing=False,
):
......@@ -127,7 +125,6 @@ def decode(
seqlen_og,
max_length,
tensor_parallel=tensor_parallel,
fused_ft_kernel=fused_ft_kernel,
)
inference_params = model._decoding_cache.inference_params
inference_params.max_sequence_len = max_length
......@@ -135,9 +132,7 @@ def decode(
inference_params.sequence_len_offset = 0
inference_params.lengths_per_sample.zero_()
else:
inference_params = InferenceParams(
max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
)
inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
def get_logits(input_ids, inference_params):
decoding = inference_params.sequence_len_offset > 0
......@@ -273,7 +268,6 @@ def decode_speculative(
eos_token_id=None,
vocab_size=None,
tensor_parallel=1,
fused_ft_kernel=False,
cg=False,
enable_timing=False,
debug=False,
......@@ -307,23 +301,17 @@ def decode_speculative(
seqlen_og,
max_length,
tensor_parallel=tensor_parallel,
fused_ft_kernel=fused_ft_kernel,
)
inference_params_draft = model_draft._decoding_cache.inference_params
inference_params_draft.max_sequence_len = max_length
inference_params_draft.max_batch_size = batch_size
inference_params_draft.sequence_len_offset = 0
# fused_ft_kernel doesn't support passing in multiple tokens at once
inference_params = InferenceParams(
max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=False
)
inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
else:
inference_params_draft = InferenceParams(
max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
)
inference_params = InferenceParams(
max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=False
max_sequence_len=max_length, max_batch_size=batch_size
)
inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
def logits_forward_fn(model, input_ids, position_ids, inference_params, cg=False):
if not cg:
......@@ -606,7 +594,6 @@ def allocate_inference_cache(
layers: Union[int, Sequence],
device,
dtype=torch.float16,
fused_ft_kernel=False,
):
assert dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if dtype == torch.float32 else 8
......@@ -616,15 +603,7 @@ def allocate_inference_cache(
kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
if isinstance(layers, int):
layers = range(layers)
return {
i: (
torch.empty(k_cache_shape, device=device, dtype=dtype),
torch.empty(v_cache_shape, device=device, dtype=dtype),
)
if fused_ft_kernel
else torch.empty(kv_cache_sahpe, device=device, dtype=dtype)
for i in layers
}
return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}
def seqlen_to_seqlen_type(seqlen: int) -> int:
......@@ -633,12 +612,12 @@ def seqlen_to_seqlen_type(seqlen: int) -> int:
Arguments:
seqlen: int
"""
return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)
return 0
def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
assert seqlen_type in [0, 1, 2]
return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2**32)
assert seqlen_type in [0]
return 2**32
@dataclass
......@@ -663,7 +642,6 @@ def update_graph_cache(
tensor_parallel=1,
dtype=None,
n_warmups=2,
fused_ft_kernel=False,
):
if cache is None:
cache = DecodingCGCache()
......@@ -683,9 +661,7 @@ def update_graph_cache(
cache.device, cache.dtype = device, dtype
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
if hasattr(model, "allocate_inference_cache"):
inf_cache = model.allocate_inference_cache(
batch_size, max_seqlen, dtype, fused_ft_kernel=fused_ft_kernel
)
inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
else:
headdim = getattr(
model.config,
......@@ -700,7 +676,6 @@ def update_graph_cache(
model.config.num_hidden_layers,
device,
dtype,
fused_ft_kernel=fused_ft_kernel,
)
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
cache.inference_params = InferenceParams(
......@@ -708,7 +683,6 @@ def update_graph_cache(
max_batch_size=batch_size,
sequence_len_offset=seqlen_og,
key_value_memory_dict=inf_cache,
fused_ft_kernel=fused_ft_kernel,
lengths_per_sample=lengths_per_sample,
)
cache.mempool = torch.cuda.graphs.graph_pool_handle()
......
......@@ -122,10 +122,10 @@ if not SKIP_CUDA_BUILD:
# cc_flag.append("arch=compute_75,code=sm_75")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
if CUDA_HOME is not None:
if bare_metal_version >= Version("11.8"):
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")
# if CUDA_HOME is not None:
# if bare_metal_version >= Version("11.8"):
# cc_flag.append("-gencode")
# cc_flag.append("arch=compute_90,code=sm_90")
# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
# torch._C._GLIBCXX_USE_CXX11_ABI
......
......@@ -217,9 +217,8 @@ def test_baichuan_parallel_forward(model_name, world_size):
).abs().max().item()
@pytest.mark.parametrize("fused_ft_kernel", [False, True])
@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
def test_baichuan_generation(model_name, fused_ft_kernel):
def test_baichuan_generation(model_name):
dtype = torch.float16
device = "cuda"
config = baichuan_config_to_gpt2_config(
......@@ -236,8 +235,8 @@ def test_baichuan_generation(model_name, fused_ft_kernel):
torch.manual_seed(0)
batch_size = 1
seqlen = 100
max_length = 150
seqlen = 2048
max_length = 2048 + 150
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
......@@ -285,7 +284,6 @@ def test_baichuan_generation(model_name, fused_ft_kernel):
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
......@@ -296,16 +294,13 @@ def test_baichuan_generation(model_name, fused_ft_kernel):
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(
model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
)
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=fused_ft_kernel,
cg=True,
return_dict_in_generate=True,
output_scores=True,
......@@ -403,9 +398,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(
model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=False
)
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
out_cg = model.generate(
input_ids=input_ids,
......
......@@ -141,7 +141,6 @@ def test_bigcode_generation(model_name):
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
fused_ft_kernel=True,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
......@@ -159,7 +158,6 @@ def test_bigcode_generation(model_name):
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=True,
cg=True,
return_dict_in_generate=True,
output_scores=True,
......
......@@ -242,7 +242,6 @@ def test_falcon_generation(model_name):
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
fused_ft_kernel=True,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
......@@ -253,16 +252,13 @@ def test_falcon_generation(model_name):
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(
model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
)
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=True,
cg=True,
return_dict_in_generate=True,
output_scores=True,
......@@ -349,7 +345,6 @@ def test_falcon_parallel_generation(model_name, world_size):
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=True,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate=True,
output_scores=True,
......@@ -358,16 +353,13 @@ def test_falcon_parallel_generation(model_name, world_size):
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(
model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
)
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=True,
cg=True,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate=True,
......
......@@ -134,14 +134,12 @@ def test_gpt2_optimized(model_name):
).abs().max().item()
@pytest.mark.parametrize("fused_ft_kernel", [False, True])
# @pytest.mark.parametrize('fused_ft_kernel', [True])
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [True])
@pytest.mark.parametrize("rotary", [False, True])
# @pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
def test_gpt2_generation(model_name, rotary, optimized):
"""Check that our implementation of GPT2 generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
......@@ -202,18 +200,16 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
out = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
)
print(out.sequences)
print(tokenizer.batch_decode(out.sequences.tolist()))
if fused_ft_kernel or getattr(config, "use_flash_attn", False):
if getattr(config, "use_flash_attn", False):
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=fused_ft_kernel,
cg=True,
return_dict_in_generate=True,
output_scores=True,
......@@ -282,10 +278,8 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
# @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
@pytest.mark.parametrize("rotary", [None, "interleaved", "contiguous"])
# @pytest.mark.parametrize('rotary', [None])
@pytest.mark.parametrize("fused_ft_kernel", [False, True])
# @pytest.mark.parametrize("fused_ft_kernel", [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_gpt2_generation_cg(model_name, fused_ft_kernel, rotary, seqlen, maxlen):
def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
"""Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
dtype = torch.float16
device = "cuda"
......@@ -315,17 +309,8 @@ def test_gpt2_generation_cg(model_name, fused_ft_kernel, rotary, seqlen, maxlen)
0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
)
logits = get_logits(
model, input_ids, maxlen, teacher_outputs=teacher_outputs, fused_ft_kernel=fused_ft_kernel
)
logits_cg = get_logits(
model,
input_ids,
maxlen,
teacher_outputs=teacher_outputs,
fused_ft_kernel=fused_ft_kernel,
cg=True,
)
logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
assert torch.equal(logits, logits_cg)
# Try increasing batch size and seqlen, then decrease them to see if it's still correct
......@@ -369,7 +354,6 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
# fused_ft_kernel currently doesn't work with multiple tokens at a time
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
......@@ -398,13 +382,12 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
assert torch.allclose(logits, logits_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("fused_ft_kernel, cg", [(False, False), (True, False), (True, True)])
# @pytest.mark.parametrize("fused_ft_kernel, cg", [(True, True)])
@pytest.mark.parametrize("cg", [False, True])
# @pytest.mark.parametrize("optimized", [False, True])
@pytest.mark.parametrize("optimized", [True])
# @pytest.mark.parametrize("model_name", ["gpt2-medium"])
@pytest.mark.parametrize("model_name", ["gpt2-xl"])
def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
def test_gpt2_speculative_decoding(model_name, optimized, cg):
dtype = torch.float16
device = "cuda"
rtol, atol = 3e-3, 3e-1
......@@ -444,7 +427,6 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
model_draft,
max_length=max_length,
top_k=5,
fused_ft_kernel=fused_ft_kernel,
cg=cg,
speculative_lookahead=4,
enable_timing=True,
......@@ -454,7 +436,6 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
input_ids,
max_length=max_length,
top_k=5,
fused_ft_kernel=fused_ft_kernel,
cg=False,
enable_timing=True,
return_dict_in_generate=True,
......
......@@ -15,12 +15,10 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHead
# @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize("world_size", [2])
# @pytest.mark.parametrize('fused_ft_kernel', [False, True])
@pytest.mark.parametrize("fused_ft_kernel", [True])
# @pytest.mark.parametrize('rotary', [False, True])
@pytest.mark.parametrize("rotary", [False])
@pytest.mark.parametrize('rotary', [False, True])
# @pytest.mark.parametrize("rotary", [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
def test_tensor_parallel(model_name, rotary, world_size):
"""Check that our implementation of GPT2 generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
......@@ -111,19 +109,17 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
)
print(out.sequences)
if fused_ft_kernel:
if getattr(config, "use_flash_attn", False):
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=fused_ft_kernel,
cg=True,
return_dict_in_generate=True,
output_scores=True,
......
......@@ -83,9 +83,8 @@ def test_gptj_optimized(model_name):
).abs().max().item()
@pytest.mark.parametrize("fused_ft_kernel", [False, True])
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
def test_gptj_generation(model_name, fused_ft_kernel):
def test_gptj_generation(model_name):
"""Check that our implementation of GPT-J (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
......@@ -141,7 +140,6 @@ def test_gptj_generation(model_name, fused_ft_kernel):
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
......@@ -152,16 +150,13 @@ def test_gptj_generation(model_name, fused_ft_kernel):
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(
model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
)
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=fused_ft_kernel,
cg=True,
return_dict_in_generate=True,
output_scores=True,
......
......@@ -292,7 +292,6 @@ def test_llama_generation(model_name, checkpoint_format):
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
fused_ft_kernel=True,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
......@@ -303,16 +302,13 @@ def test_llama_generation(model_name, checkpoint_format):
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(
model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
)
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=True,
cg=True,
return_dict_in_generate=True,
output_scores=True,
......@@ -401,7 +397,6 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=True,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate=True,
output_scores=True,
......@@ -410,16 +405,13 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(
model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
)
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=True,
cg=True,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate=True,
......
......@@ -107,7 +107,6 @@ def test_opt_generation(model_name):
dtype = torch.float16
device = "cuda"
rtol, atol = 3e-3, 3e-1
fused_ft_kernel = True
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
# Only prenorm supports residual_in_fp32
config.residual_in_fp32 = getattr(config, "prenorm", True)
......@@ -155,7 +154,6 @@ def test_opt_generation(model_name):
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
......@@ -165,19 +163,16 @@ def test_opt_generation(model_name):
if verbose:
print(out.sequences)
print(tokenizer.batch_decode(out.sequences.tolist()))
if fused_ft_kernel:
if getattr(config, "use_flash_attn", False):
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(
model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=True
)
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=fused_ft_kernel,
cg=True,
return_dict_in_generate=True,
output_scores=True,
......