Commit a86442f0 authored by Tri Dao

[Gen] Use flash_attn_with_kvcache in generation

parent a1576ad1
......@@ -146,7 +146,8 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
# Call 1 kernel instead of 2 kernels
# We need qkv to be contiguous so that when we reshape to combine (3, nheads)
# dimensions, we get the same tensor
qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
# qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
apply_rotary(
qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
)
......
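The switch from `rearrange` to a plain `reshape` above relies on the sliced tensor being view-compatible, so the in-place rotary update is written back into `qkv` itself. A minimal, illustrative sketch of that invariant (shapes here are made up):

```python
import torch

batch, seqlen, nheads, headdim = 2, 5, 4, 8
qkv = torch.randn(batch, seqlen, 3, nheads, headdim)

# Slicing the "three" dimension keeps the original strides, and merging
# (2, nheads) into 2*nheads is stride-compatible, so reshape is expected to
# return a view rather than a copy.
qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
assert qk.data_ptr() == qkv.data_ptr()  # same underlying storage, no copy
qk.zero_()                              # stand-in for the in-place rotary kernel
assert (qkv[:, :, :2] == 0).all()       # the write is visible through qkv
assert not (qkv[:, :, 2] == 0).all()    # v was left untouched
```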
......@@ -15,10 +15,12 @@ try:
flash_attn_qkvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_with_kvcache,
)
except ImportError:
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
flash_attn_with_kvcache = None
try:
from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
......@@ -556,6 +558,35 @@ class MHA(nn.Module):
else False,
)
def _update_kvcache_attention(self, q, kv, inference_params):
"""Write kv to inference_params, then do attention """
if (
inference_params.sequence_len_offset == 0
or flash_attn_with_kvcache is None
or not self.use_flash_attn
):
# TODO: this only uses sequence_len_offset and not lengths_per_sample.
kv = self._update_kv_cache(kv, inference_params)
return self.inner_cross_attn(q, kv)
else:
batch = q.shape[0]
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
else inference_params.sequence_len_offset
)
return flash_attn_with_kvcache(
q,
kv_cache[:, :, 0],
kv_cache[:, :, 1],
kv[:, :, 0],
kv[:, :, 1],
cache_seqlens=cache_seqlens,
softmax_scale=self.inner_cross_attn.softmax_scale,
causal=self.inner_cross_attn.causal,
)
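For context, a minimal sketch of the `flash_attn_with_kvcache` call this method now dispatches to during decoding (assumes a flash-attn build that exposes the kernel; shapes and values are illustrative). The kernel writes the new k/v into the pre-allocated cache at `cache_seqlens` in place and attends the query against the cached prefix plus the new tokens:

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_seqlen = 2, 8, 64, 256
device, dtype = "cuda", torch.float16

# Pre-allocated caches in the (batch, seqlen, nheads, headdim) layout.
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device=device, dtype=dtype)
# Number of tokens already cached for each sequence (per-sample, int32).
cache_seqlens = torch.full((batch,), 16, dtype=torch.int32, device=device)

# One decoding step: a single new token per sequence.
q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
k_new = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
v_new = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)

out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k_new, v_new, cache_seqlens=cache_seqlens, causal=True
)
print(out.shape)  # (batch, 1, nheads, headdim)
```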
def forward(
self,
x,
......@@ -605,10 +636,19 @@ class MHA(nn.Module):
if self.use_flash_attn
else {"key_padding_mask": key_padding_mask, **kwargs}
)
seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
seqlen_offset = (
0
if inference_params is None
else (
inference_params.lengths_per_sample
if inference_params.lengths_per_sample is not None
else inference_params.sequence_len_offset
)
)
rotary_max_seqlen = (
inference_params.max_sequence_len if inference_params is not None else None
)
batch, seqlen = x.shape[:2]
if not self.cross_attn and self.num_heads_kv == self.num_heads:
assert x_kv is None and mixer_subset is None
if not self.return_residual:
......@@ -619,7 +659,8 @@ class MHA(nn.Module):
qkv = rearrange(
self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
).contiguous()
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
# qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
qkv = qkv.reshape(batch, seqlen, 3, self.num_heads, self.head_dim)
if (
inference_params is None
or inference_params.sequence_len_offset == 0
......@@ -635,9 +676,9 @@ class MHA(nn.Module):
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
else:
q = qkv[:, :, 0]
kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
context = self.inner_cross_attn(q, kv)
context = self._update_kvcache_attention(
qkv[:, :, 0], qkv[:, :, 1:], inference_params
)
else:
context = self._apply_rotary_single_query_attention(qkv, inference_params)
else:
......@@ -659,8 +700,10 @@ class MHA(nn.Module):
qkv, x = self.Wqkv(x)
q = qkv[..., : self.num_heads * self.head_dim]
kv = qkv[..., self.num_heads * self.head_dim :]
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
# q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
q = q.reshape(batch, seqlen, -1, self.head_dim)
# kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
kv = kv.reshape(batch, seqlen, 2, -1, self.head_dim)
if self.dwconv:
q = rearrange(
self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
......@@ -685,11 +728,11 @@ class MHA(nn.Module):
self.inner_cross_attn, q, kv, **kwargs
)
else:
kv = self._update_kv_cache(kv, inference_params)
context = self.inner_cross_attn(q, kv)
context = self._update_kvcache_attention(q, kv, inference_params)
else:
context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
# out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
out = self.out_proj(context.reshape(batch, seqlen, -1))
return out if not self.return_residual else (out, x)
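A standalone illustration of the reshape-based q/kv split used above when `num_heads_kv != num_heads` (grouped-query / multi-query attention); all sizes here are made up:

```python
import torch

batch, seqlen, num_heads, num_heads_kv, head_dim = 2, 5, 8, 2, 64
# Packed Wqkv output: q heads first, then (k, v) heads.
qkv = torch.randn(batch, seqlen, (num_heads + 2 * num_heads_kv) * head_dim)

q = qkv[..., : num_heads * head_dim].reshape(batch, seqlen, -1, head_dim)
kv = qkv[..., num_heads * head_dim :].reshape(batch, seqlen, 2, -1, head_dim)
assert q.shape == (batch, seqlen, num_heads, head_dim)
assert kv.shape == (batch, seqlen, 2, num_heads_kv, head_dim)
```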
......@@ -846,6 +889,36 @@ class ParallelMHA(nn.Module):
else False,
)
def _update_kvcache_attention(self, q, kv, inference_params):
"""Write kv to inference_params, then do attention """
if (
inference_params.sequence_len_offset == 0
or flash_attn_with_kvcache is None
or not self.use_flash_attn
):
# TODO: this only uses sequence_len_offset and not lengths_per_sample.
kv = self._update_kv_cache(kv, inference_params)
return self.inner_cross_attn(q, kv)
else:
batch = q.shape[0]
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
else inference_params.sequence_len_offset
)
context = flash_attn_with_kvcache(
q,
kv_cache[:, :, 0],
kv_cache[:, :, 1],
kv[:, :, 0],
kv[:, :, 1],
cache_seqlens=cache_seqlens,
softmax_scale=self.inner_cross_attn.softmax_scale,
causal=self.inner_cross_attn.causal,
)
return context
def forward(self, x, seqlen=None, inference_params=None, **kwargs):
"""
Arguments:
......@@ -857,7 +930,15 @@ class ParallelMHA(nn.Module):
qkv = self.Wqkv(x)
if seqlen is not None:
qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
seqlen_offset = (
0
if inference_params is None
else (
inference_params.lengths_per_sample
if inference_params.lengths_per_sample is not None
else inference_params.sequence_len_offset
)
)
rotary_max_seqlen = (
inference_params.max_sequence_len if inference_params is not None else None
)
......@@ -878,9 +959,9 @@ class ParallelMHA(nn.Module):
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
else:
q = qkv[:, :, 0]
kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
context = self.inner_cross_attn(q, kv)
context = self._update_kvcache_attention(
qkv[:, :, 0], qkv[:, :, 1:], inference_params
)
else:
context = self._apply_rotary_single_query_attention(qkv, inference_params)
else:
......@@ -912,8 +993,7 @@ class ParallelMHA(nn.Module):
self.inner_cross_attn, q, kv, **kwargs
)
else:
kv = self._update_kv_cache(kv, inference_params)
context = self.inner_cross_attn(q, kv)
context = self._update_kvcache_attention(q, kv, inference_params)
else:
context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
context = rearrange(context, "b s h d -> b s (h d)")
......
......@@ -118,7 +118,6 @@ def decode(
batch_size, seqlen_og = input_ids.shape
teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
if cg:
assert fused_ft_kernel
if not hasattr(model, "_decoding_cache"):
model._decoding_cache = None
model._decoding_cache = update_graph_cache(
......@@ -128,11 +127,13 @@ def decode(
seqlen_og,
max_length,
tensor_parallel=tensor_parallel,
fused_ft_kernel=fused_ft_kernel,
)
inference_params = model._decoding_cache.inference_params
inference_params.max_sequence_len = max_length
inference_params.max_batch_size = batch_size
inference_params.sequence_len_offset = 0
inference_params.lengths_per_sample.zero_()
else:
inference_params = InferenceParams(
max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
......@@ -167,7 +168,8 @@ def decode(
token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
else:
token = teacher_outputs[:, inference_params.sequence_len_offset]
return rearrange(token, "b -> b 1")
# return rearrange(token, "b -> b 1")
return token.unsqueeze(1)
def should_stop(current_token, inference_params):
if inference_params.sequence_len_offset == 0:
......@@ -197,9 +199,7 @@ def decode(
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
return output_cls(
sequences=torch.cat(sequences, dim=1), scores=tuple(scores)
)
return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0):
......@@ -298,7 +298,6 @@ def decode_speculative(
assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1"
assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id"
if cg:
assert fused_ft_kernel
if not hasattr(model_draft, "_decoding_cache"):
model_draft._decoding_cache = None
model_draft._decoding_cache = update_graph_cache(
......@@ -308,6 +307,7 @@ def decode_speculative(
seqlen_og,
max_length,
tensor_parallel=tensor_parallel,
fused_ft_kernel=fused_ft_kernel,
)
inference_params_draft = model_draft._decoding_cache.inference_params
inference_params_draft.max_sequence_len = max_length
......@@ -606,12 +606,14 @@ def allocate_inference_cache(
layers: Union[int, Sequence],
device,
dtype=torch.float16,
fused_ft_kernel=False,
):
assert dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if dtype == torch.float32 else 8
assert headdim % packsize == 0
k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
if isinstance(layers, int):
layers = range(layers)
return {
......@@ -619,6 +621,8 @@ def allocate_inference_cache(
torch.empty(k_cache_shape, device=device, dtype=dtype),
torch.empty(v_cache_shape, device=device, dtype=dtype),
)
if fused_ft_kernel
else torch.empty(kv_cache_shape, device=device, dtype=dtype)
for i in layers
}
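To make the two cache layouts above concrete, here is a small sketch (sizes are illustrative) of what a single layer's entry looks like in each branch. With `fused_ft_kernel=False`, the per-layer entry is the single `kv_cache` tensor that `_update_kvcache_attention` slices into k and v via `kv_cache[:, :, 0]` and `kv_cache[:, :, 1]`:

```python
import torch

max_batch_size, max_seqlen, nheads, headdim = 2, 256, 8, 64
dtype, device = torch.float16, "cuda"
packsize = 4 if dtype == torch.float32 else 8

# Layout consumed by the fused FT decoding kernel: k packed along headdim, v row-major.
k_cache = torch.empty(max_batch_size, nheads, headdim // packsize, max_seqlen, packsize,
                      device=device, dtype=dtype)
v_cache = torch.empty(max_batch_size, nheads, max_seqlen, headdim, device=device, dtype=dtype)

# Layout consumed by flash_attn_with_kvcache: one tensor with k/v stacked on dim 2.
kv_cache = torch.empty(max_batch_size, max_seqlen, 2, nheads, headdim,
                       device=device, dtype=dtype)
```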
......@@ -651,7 +655,15 @@ class DecodingCGCache:
@torch.inference_mode()
def update_graph_cache(
model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1, dtype=None, n_warmups=2
model,
cache,
batch_size,
seqlen_og,
max_seqlen,
tensor_parallel=1,
dtype=None,
n_warmups=2,
fused_ft_kernel=False,
):
if cache is None:
cache = DecodingCGCache()
......@@ -671,7 +683,9 @@ def update_graph_cache(
cache.device, cache.dtype = device, dtype
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
if hasattr(model, "allocate_inference_cache"):
inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
inf_cache = model.allocate_inference_cache(
batch_size, max_seqlen, dtype, fused_ft_kernel=fused_ft_kernel
)
else:
headdim = getattr(
model.config,
......@@ -686,6 +700,7 @@ def update_graph_cache(
model.config.num_hidden_layers,
device,
dtype,
fused_ft_kernel=fused_ft_kernel,
)
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
cache.inference_params = InferenceParams(
......@@ -693,7 +708,7 @@ def update_graph_cache(
max_batch_size=batch_size,
sequence_len_offset=seqlen_og,
key_value_memory_dict=inf_cache,
fused_ft_kernel=True,
fused_ft_kernel=fused_ft_kernel,
lengths_per_sample=lengths_per_sample,
)
cache.mempool = torch.cuda.graphs.graph_pool_handle()
......
......@@ -217,8 +217,9 @@ def test_baichuan_parallel_forward(model_name, world_size):
).abs().max().item()
@pytest.mark.parametrize("fused_ft_kernel", [False, True])
@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
def test_baichuan_generation(model_name):
def test_baichuan_generation(model_name, fused_ft_kernel):
dtype = torch.float16
device = "cuda"
config = baichuan_config_to_gpt2_config(
......@@ -276,6 +277,7 @@ def test_baichuan_generation(model_name):
model.load_state_dict(pretrained_state_dict)
model.eval()
model(input_ids) # Warm up
print("Without CUDA graph")
torch.cuda.synchronize()
start = time.time()
......@@ -283,7 +285,7 @@ def test_baichuan_generation(model_name):
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
fused_ft_kernel=True,
fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
......@@ -295,7 +297,7 @@ def test_baichuan_generation(model_name):
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(
model, None, batch_size, seqlen_og, max_length
model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
)
print("With CUDA graph")
torch.cuda.synchronize()
......@@ -303,7 +305,7 @@ def test_baichuan_generation(model_name):
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=True,
fused_ft_kernel=fused_ft_kernel,
cg=True,
return_dict_in_generate=True,
output_scores=True,
......@@ -346,7 +348,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
config = baichuan_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
config.use_flash_attn = False
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused GatedMLP yet
config.fused_dropout_add_ln = False
......@@ -393,7 +395,6 @@ def test_baichuan_parallel_generation(model_name, world_size):
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=True,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate=True,
output_scores=True,
......@@ -411,7 +412,6 @@ def test_baichuan_parallel_generation(model_name, world_size):
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=True,
cg=True,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate=True,
......@@ -458,6 +458,6 @@ def test_baichuan_parallel_generation(model_name, world_size):
hf_error = (logits_hf - logits_ref).abs().max().item()
print(f"HF fp16 logits max diff: {hf_error}")
print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
assert torch.equal(logits_cg, logits)
......@@ -135,7 +135,7 @@ def test_gpt2_optimized(model_name):
@pytest.mark.parametrize("fused_ft_kernel", [False, True])
# @pytest.mark.parametrize('fused_ft_kernel', [False])
# @pytest.mark.parametrize('fused_ft_kernel', [True])
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [True])
@pytest.mark.parametrize("rotary", [False, True])
......@@ -209,7 +209,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
)
print(out.sequences)
print(tokenizer.batch_decode(out.sequences.tolist()))
if fused_ft_kernel or config.use_flash_attn:
if fused_ft_kernel or getattr(config, "use_flash_attn", False):
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
......@@ -220,6 +220,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
enable_timing=True,
)
print(out_cg.sequences)
assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1))
if not rotary:
out_hf = model_hf.generate(
......@@ -282,6 +283,7 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
@pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
# @pytest.mark.parametrize('rotary', [None])
@pytest.mark.parametrize("fused_ft_kernel", [False, True])
# @pytest.mark.parametrize("fused_ft_kernel", [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_gpt2_generation_cg(model_name, fused_ft_kernel, rotary, seqlen, maxlen):
"""Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
......