Commit a86442f0 authored by Tri Dao

[Gen] Use flash_attn_with_kvcache in generation

parent a1576ad1
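In short: during incremental decoding (sequence_len_offset > 0), MHA and ParallelMHA now route through a new _update_kvcache_attention helper that calls flash_attn_with_kvcache, which appends the current step's k/v to the cache and computes attention in a single kernel, instead of first writing the cache with _update_kv_cache and then calling inner_cross_attn. A minimal standalone sketch of that call shape (tensor names and sizes here are illustrative, not part of the commit; the cache layout matches the kv_cache_shape allocated in the generation changes below):

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_seqlen = 2, 16, 64, 256
device, dtype = "cuda", torch.float16

# Pre-allocated cache with layout (batch, max_seqlen, 2, nheads, headdim)
kv_cache = torch.zeros(batch, max_seqlen, 2, nheads, headdim, device=device, dtype=dtype)
# q/k/v for the single new token of this decoding step
q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
kv = torch.randn(batch, 1, 2, nheads, headdim, device=device, dtype=dtype)
# How many tokens each sequence already has in the cache (lengths_per_sample)
cache_seqlens = torch.full((batch,), 17, dtype=torch.int32, device=device)

out = flash_attn_with_kvcache(
    q,
    kv_cache[:, :, 0],  # k_cache: (batch, max_seqlen, nheads, headdim)
    kv_cache[:, :, 1],  # v_cache
    kv[:, :, 0],        # new k, written into the cache at position cache_seqlens
    kv[:, :, 1],        # new v
    cache_seqlens=cache_seqlens,
    causal=True,
)  # out: (batch, 1, nheads, headdim)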
@@ -146,7 +146,8 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
             # Call 1 kernel instead of 2 kernels
             # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
             # dimensions, we get the same tensor
-            qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
+            # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
+            qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
             apply_rotary(
                 qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
             )
...
@@ -15,10 +15,12 @@ try:
         flash_attn_qkvpacked_func,
         flash_attn_varlen_kvpacked_func,
         flash_attn_varlen_qkvpacked_func,
+        flash_attn_with_kvcache,
     )
 except ImportError:
     flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
     flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
+    flash_attn_with_kvcache = None

 try:
     from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
@@ -556,6 +558,35 @@ class MHA(nn.Module):
             else False,
         )

+    def _update_kvcache_attention(self, q, kv, inference_params):
+        """Write kv to inference_params, then do attention"""
+        if (
+            inference_params.sequence_len_offset == 0
+            or flash_attn_with_kvcache is None
+            or not self.use_flash_attn
+        ):
+            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+            kv = self._update_kv_cache(kv, inference_params)
+            return self.inner_cross_attn(q, kv)
+        else:
+            batch = q.shape[0]
+            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
+            cache_seqlens = (
+                inference_params.lengths_per_sample[:batch]
+                if inference_params.lengths_per_sample is not None
+                else inference_params.sequence_len_offset
+            )
+            return flash_attn_with_kvcache(
+                q,
+                kv_cache[:, :, 0],
+                kv_cache[:, :, 1],
+                kv[:, :, 0],
+                kv[:, :, 1],
+                cache_seqlens=cache_seqlens,
+                softmax_scale=self.inner_cross_attn.softmax_scale,
+                causal=self.inner_cross_attn.causal,
+            )
+
     def forward(
         self,
         x,
@@ -605,10 +636,19 @@ class MHA(nn.Module):
             if self.use_flash_attn
             else {"key_padding_mask": key_padding_mask, **kwargs}
         )
-        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
+        seqlen_offset = (
+            0
+            if inference_params is None
+            else (
+                inference_params.lengths_per_sample
+                if inference_params.lengths_per_sample is not None
+                else inference_params.sequence_len_offset
+            )
+        )
         rotary_max_seqlen = (
             inference_params.max_sequence_len if inference_params is not None else None
         )
+        batch, seqlen = x.shape[:2]
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
             if not self.return_residual:
@@ -619,7 +659,8 @@ class MHA(nn.Module):
                 qkv = rearrange(
                     self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                 ).contiguous()
-            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
+            # qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
+            qkv = qkv.reshape(batch, seqlen, 3, self.num_heads, self.head_dim)
             if (
                 inference_params is None
                 or inference_params.sequence_len_offset == 0
@@ -635,9 +676,9 @@ class MHA(nn.Module):
                     else:
                         context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                 else:
-                    q = qkv[:, :, 0]
-                    kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
-                    context = self.inner_cross_attn(q, kv)
+                    context = self._update_kvcache_attention(
+                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
+                    )
             else:
                 context = self._apply_rotary_single_query_attention(qkv, inference_params)
         else:
@@ -659,8 +700,10 @@ class MHA(nn.Module):
                 qkv, x = self.Wqkv(x)
                 q = qkv[..., : self.num_heads * self.head_dim]
                 kv = qkv[..., self.num_heads * self.head_dim :]
-            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
-            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
+            # q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
+            q = q.reshape(batch, seqlen, -1, self.head_dim)
+            # kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
+            kv = kv.reshape(batch, seqlen, 2, -1, self.head_dim)
             if self.dwconv:
                 q = rearrange(
                     self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
@@ -685,11 +728,11 @@ class MHA(nn.Module):
                         self.inner_cross_attn, q, kv, **kwargs
                     )
                 else:
-                    kv = self._update_kv_cache(kv, inference_params)
-                    context = self.inner_cross_attn(q, kv)
+                    context = self._update_kvcache_attention(q, kv, inference_params)
             else:
                 context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
-        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
+        # out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
+        out = self.out_proj(context.reshape(batch, seqlen, -1))
         return out if not self.return_residual else (out, x)
@@ -846,6 +889,36 @@ class ParallelMHA(nn.Module):
             else False,
         )

+    def _update_kvcache_attention(self, q, kv, inference_params):
+        """Write kv to inference_params, then do attention"""
+        if (
+            inference_params.sequence_len_offset == 0
+            or flash_attn_with_kvcache is None
+            or not self.use_flash_attn
+        ):
+            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+            kv = self._update_kv_cache(kv, inference_params)
+            return self.inner_cross_attn(q, kv)
+        else:
+            batch = q.shape[0]
+            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
+            cache_seqlens = (
+                inference_params.lengths_per_sample[:batch]
+                if inference_params.lengths_per_sample is not None
+                else inference_params.sequence_len_offset
+            )
+            context = flash_attn_with_kvcache(
+                q,
+                kv_cache[:, :, 0],
+                kv_cache[:, :, 1],
+                kv[:, :, 0],
+                kv[:, :, 1],
+                cache_seqlens=cache_seqlens,
+                softmax_scale=self.inner_cross_attn.softmax_scale,
+                causal=self.inner_cross_attn.causal,
+            )
+            return context
+
     def forward(self, x, seqlen=None, inference_params=None, **kwargs):
         """
         Arguments:
@@ -857,7 +930,15 @@ class ParallelMHA(nn.Module):
         qkv = self.Wqkv(x)
         if seqlen is not None:
             qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
-        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
+        seqlen_offset = (
+            0
+            if inference_params is None
+            else (
+                inference_params.lengths_per_sample
+                if inference_params.lengths_per_sample is not None
+                else inference_params.sequence_len_offset
+            )
+        )
         rotary_max_seqlen = (
             inference_params.max_sequence_len if inference_params is not None else None
         )
@@ -878,9 +959,9 @@ class ParallelMHA(nn.Module):
                     else:
                         context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                 else:
-                    q = qkv[:, :, 0]
-                    kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
-                    context = self.inner_cross_attn(q, kv)
+                    context = self._update_kvcache_attention(
+                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
+                    )
             else:
                 context = self._apply_rotary_single_query_attention(qkv, inference_params)
         else:
@@ -912,8 +993,7 @@ class ParallelMHA(nn.Module):
                         self.inner_cross_attn, q, kv, **kwargs
                     )
                 else:
-                    kv = self._update_kv_cache(kv, inference_params)
-                    context = self.inner_cross_attn(q, kv)
+                    context = self._update_kvcache_attention(q, kv, inference_params)
             else:
                 context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
         context = rearrange(context, "b s h d -> b s (h d)")
...
@@ -118,7 +118,6 @@ def decode(
     batch_size, seqlen_og = input_ids.shape
     teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
     if cg:
-        assert fused_ft_kernel
         if not hasattr(model, "_decoding_cache"):
             model._decoding_cache = None
         model._decoding_cache = update_graph_cache(
@@ -128,11 +127,13 @@ def decode(
             seqlen_og,
             max_length,
             tensor_parallel=tensor_parallel,
+            fused_ft_kernel=fused_ft_kernel,
         )
         inference_params = model._decoding_cache.inference_params
         inference_params.max_sequence_len = max_length
         inference_params.max_batch_size = batch_size
         inference_params.sequence_len_offset = 0
+        inference_params.lengths_per_sample.zero_()
     else:
         inference_params = InferenceParams(
             max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
@@ -167,7 +168,8 @@ def decode(
             token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
         else:
             token = teacher_outputs[:, inference_params.sequence_len_offset]
-        return rearrange(token, "b -> b 1")
+        # return rearrange(token, "b -> b 1")
+        return token.unsqueeze(1)

     def should_stop(current_token, inference_params):
         if inference_params.sequence_len_offset == 0:
@@ -197,9 +199,7 @@ def decode(
         torch.cuda.synchronize()
         print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
-    return output_cls(
-        sequences=torch.cat(sequences, dim=1), scores=tuple(scores)
-    )
+    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))


 def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0):
@@ -298,7 +298,6 @@ def decode_speculative(
     assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1"
     assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id"
     if cg:
-        assert fused_ft_kernel
         if not hasattr(model_draft, "_decoding_cache"):
             model_draft._decoding_cache = None
         model_draft._decoding_cache = update_graph_cache(
@@ -308,6 +307,7 @@ def decode_speculative(
             seqlen_og,
             max_length,
             tensor_parallel=tensor_parallel,
+            fused_ft_kernel=fused_ft_kernel,
         )
         inference_params_draft = model_draft._decoding_cache.inference_params
         inference_params_draft.max_sequence_len = max_length
@@ -606,12 +606,14 @@ def allocate_inference_cache(
     layers: Union[int, Sequence],
     device,
     dtype=torch.float16,
+    fused_ft_kernel=False,
 ):
     assert dtype in [torch.float16, torch.bfloat16, torch.float32]
     packsize = 4 if dtype == torch.float32 else 8
     assert headdim % packsize == 0
     k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
     v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
+    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
     if isinstance(layers, int):
         layers = range(layers)
     return {
@@ -619,6 +621,8 @@ def allocate_inference_cache(
             torch.empty(k_cache_shape, device=device, dtype=dtype),
             torch.empty(v_cache_shape, device=device, dtype=dtype),
         )
+        if fused_ft_kernel
+        else torch.empty(kv_cache_shape, device=device, dtype=dtype)
         for i in layers
     }
@@ -651,7 +655,15 @@ class DecodingCGCache:

 @torch.inference_mode()
 def update_graph_cache(
-    model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1, dtype=None, n_warmups=2
+    model,
+    cache,
+    batch_size,
+    seqlen_og,
+    max_seqlen,
+    tensor_parallel=1,
+    dtype=None,
+    n_warmups=2,
+    fused_ft_kernel=False,
 ):
     if cache is None:
         cache = DecodingCGCache()
@@ -671,7 +683,9 @@ def update_graph_cache(
         cache.device, cache.dtype = device, dtype
         cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
         if hasattr(model, "allocate_inference_cache"):
-            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
+            inf_cache = model.allocate_inference_cache(
+                batch_size, max_seqlen, dtype, fused_ft_kernel=fused_ft_kernel
+            )
         else:
             headdim = getattr(
                 model.config,
@@ -686,6 +700,7 @@ def update_graph_cache(
                 model.config.num_hidden_layers,
                 device,
                 dtype,
+                fused_ft_kernel=fused_ft_kernel,
             )
         lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
         cache.inference_params = InferenceParams(
@@ -693,7 +708,7 @@ def update_graph_cache(
             max_batch_size=batch_size,
             sequence_len_offset=seqlen_og,
             key_value_memory_dict=inf_cache,
-            fused_ft_kernel=True,
+            fused_ft_kernel=fused_ft_kernel,
             lengths_per_sample=lengths_per_sample,
         )
         cache.mempool = torch.cuda.graphs.graph_pool_handle()
...
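For context on the cache layout used above: when fused_ft_kernel=False, allocate_inference_cache now returns one packed kv tensor per layer, whose k/v slices are exactly what _update_kvcache_attention passes to flash_attn_with_kvcache. A minimal illustration with made-up sizes (not part of the commit):

import torch

max_batch_size, max_seqlen, nheads, headdim = 2, 256, 16, 64
# kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
kv_cache = torch.empty(
    max_batch_size, max_seqlen, 2, nheads, headdim, device="cuda", dtype=torch.float16
)
k_cache = kv_cache[:, :, 0]  # (batch, max_seqlen, nheads, headdim)
v_cache = kv_cache[:, :, 1]  # same shape, shares storage with kv_cache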
@@ -217,8 +217,9 @@ def test_baichuan_parallel_forward(model_name, world_size):
     ).abs().max().item()


+@pytest.mark.parametrize("fused_ft_kernel", [False, True])
 @pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
-def test_baichuan_generation(model_name):
+def test_baichuan_generation(model_name, fused_ft_kernel):
     dtype = torch.float16
     device = "cuda"
     config = baichuan_config_to_gpt2_config(
@@ -276,6 +277,7 @@ def test_baichuan_generation(model_name):
     model.load_state_dict(pretrained_state_dict)
     model.eval()

+    model(input_ids)  # Warm up
     print("Without CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
@@ -283,7 +285,7 @@ def test_baichuan_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=True,
+        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -295,7 +297,7 @@ def test_baichuan_generation(model_name):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
     model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
     )
     print("With CUDA graph")
     torch.cuda.synchronize()
@@ -303,7 +305,7 @@ def test_baichuan_generation(model_name):
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
+        fused_ft_kernel=fused_ft_kernel,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
@@ -346,7 +348,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
     config = baichuan_config_to_gpt2_config(
         AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
-    config.use_flash_attn = False
+    config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_mlp = False  # We don't have fused GatedMLP yet
     config.fused_dropout_add_ln = False
@@ -393,7 +395,6 @@ def test_baichuan_parallel_generation(model_name, world_size):
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=True,
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
@@ -411,7 +412,6 @@ def test_baichuan_parallel_generation(model_name, world_size):
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=True,
         cg=True,
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
@@ -458,6 +458,6 @@ def test_baichuan_parallel_generation(model_name, world_size):
     hf_error = (logits_hf - logits_ref).abs().max().item()
     print(f"HF fp16 logits max diff: {hf_error}")
     print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
-    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
     print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
+    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
     assert torch.equal(logits_cg, logits)
@@ -135,7 +135,7 @@ def test_gpt2_optimized(model_name):

 @pytest.mark.parametrize("fused_ft_kernel", [False, True])
-# @pytest.mark.parametrize('fused_ft_kernel', [False])
+# @pytest.mark.parametrize('fused_ft_kernel', [True])
 @pytest.mark.parametrize("optimized", [False, True])
 # @pytest.mark.parametrize('optimized', [True])
 @pytest.mark.parametrize("rotary", [False, True])
@@ -209,7 +209,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
     )
     print(out.sequences)
     print(tokenizer.batch_decode(out.sequences.tolist()))
-    if fused_ft_kernel or config.use_flash_attn:
+    if fused_ft_kernel or getattr(config, "use_flash_attn", False):
         out_cg = model.generate(
             input_ids=input_ids,
             max_length=max_length,
@@ -220,6 +220,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
             enable_timing=True,
         )
         print(out_cg.sequences)
+        assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1))

     if not rotary:
         out_hf = model_hf.generate(
@@ -282,6 +283,7 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
 @pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
 # @pytest.mark.parametrize('rotary', [None])
 @pytest.mark.parametrize("fused_ft_kernel", [False, True])
+# @pytest.mark.parametrize("fused_ft_kernel", [False])
 @pytest.mark.parametrize("model_name", ["gpt2"])
 def test_gpt2_generation_cg(model_name, fused_ft_kernel, rotary, seqlen, maxlen):
     """Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
...