"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "39b8e4430ebe84c409163b9970145b0d7d53a36c"
Commit 0938298e authored by Tri Dao

[Gen] Adjust shape of kv_cache when using FT

parent e02fd588
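For context, the docstring change below refers to two kv shapes: the whole prompt processed at once, (batch_size, seqlen, 2, nheads, head_dim), and one token per decoding step, (batch_size, 1, 2, nheads, head_dim). A minimal sketch, not code from this commit (the helper name update_kv_cache and the toy sizes are illustrative), of how such a cache is filled during greedy decoding:

# Minimal sketch (not from this commit) of filling a standard kv cache with shape
# (batch, seqlen, 2, nheads, head_dim) during greedy decoding; the helper name
# `update_kv_cache` and the toy dimensions are illustrative only.
import torch

batch, max_seqlen, nheads, head_dim = 2, 32, 12, 64
kv_cache = torch.zeros(batch, max_seqlen, 2, nheads, head_dim)

def update_kv_cache(kv, seqlen_offset):
    # kv: (batch, seqlen, 2, nheads, head_dim) for the whole prompt,
    # or (batch, 1, 2, nheads, head_dim) for each newly generated token.
    seq_end = seqlen_offset + kv.shape[1]
    kv_cache[:, seqlen_offset:seq_end] = kv
    # Attention then runs against everything cached so far.
    return kv_cache[:, :seq_end]

prompt_kv = torch.randn(batch, 5, 2, nheads, head_dim)  # prompt processed in one pass
print(update_kv_cache(prompt_kv, 0).shape)               # torch.Size([2, 5, 2, 12, 64])
step_kv = torch.randn(batch, 1, 2, nheads, head_dim)     # one decoding step
print(update_kv_cache(step_kv, 5).shape)                 # torch.Size([2, 6, 2, 12, 64])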
@@ -359,7 +359,7 @@ class MHA(nn.Module):
         self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)
 
     def _update_kv_cache(self, kv, inference_params):
-        """kv: (batch_size, 1, nheads, head_dim)
+        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
         """
         assert not self.dwconv, 'Generation does not support dwconv yet'
         assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
@@ -371,27 +371,46 @@ class MHA(nn.Module):
             )
             inference_params.key_value_memory_dict[self.layer_idx] = kv_cache
         else:
-            assert not inference_params.fused_ft_kernel, 'fused_ft_kernel should not take this path'
-            kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
+            if not inference_params.fused_ft_kernel:
+                kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
+            else:
+                # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
+                # where packsize = 4 if fp32, 8 if fp16 or bf16.
+                # v_cache has shape (b, h, s, headdim)
+                k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
+                kv_cache = None
         # Adjust key and value for inference
         batch_start = inference_params.batch_size_offset
         batch_end = batch_start + kv.shape[0]
-        assert batch_end <= kv_cache.shape[0]
         sequence_start = inference_params.sequence_len_offset
         sequence_end = sequence_start + kv.shape[1]
-        assert sequence_end <= kv_cache.shape[1]
+        assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
+        assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
         # Copy key and values.
-        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-        kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
-        if inference_params.fused_ft_kernel:
+        if not inference_params.fused_ft_kernel:
+            assert kv_cache is not None
+            kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+            kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
+            return kv
+        else:
+            assert inference_params.sequence_len_offset == 0
             # FT kernel requires different layouts for the k_cache and v_cache.
-            assert kv_cache.dtype in [torch.float16, torch.bfloat16, torch.float32]
-            packsize = 4 if kv_cache.dtype == torch.float32 else 8
-            k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
-                                packsize=packsize).contiguous()
-            v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
-            inference_params.key_value_memory_dict[self.layer_idx] = (k_cache, v_cache)
-        return kv
+            assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
+            packsize = 4 if kv.dtype == torch.float32 else 8
+            if kv_cache is not None:
+                kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+                k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
+                                    packsize=packsize).contiguous()
+                v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
+                inference_params.key_value_memory_dict[self.layer_idx] = (k_cache, v_cache)
+            else:
+                k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
+                    kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize
+                )
+                v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
+                    kv[:, :, 1], 'b s h d -> b h s d'
+                )
+            return kv
 
     def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
                 inference_params=None, **kwargs):
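To make the FT (FasterTransformer) cache layout in the comments above concrete, here is a standalone sketch, assuming torch and einops are installed; the helper `to_ft_cache` is hypothetical but uses the same rearrange patterns as the diff:

# Illustrative sketch only -- mirrors the rearrange patterns in the diff above.
# `to_ft_cache` is a hypothetical helper, not part of the repo.
import torch
from einops import rearrange

def to_ft_cache(kv_cache):
    # kv_cache: (batch, seqlen, 2, nheads, head_dim) -> FT-style (k_cache, v_cache).
    # k_cache: (batch, nheads, head_dim // packsize, seqlen, packsize),
    #          where packsize = 4 for fp32 and 8 for fp16/bf16.
    # v_cache: (batch, nheads, seqlen, head_dim)
    assert kv_cache.dtype in [torch.float16, torch.bfloat16, torch.float32]
    packsize = 4 if kv_cache.dtype == torch.float32 else 8
    k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
                        packsize=packsize).contiguous()
    v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
    return k_cache, v_cache

# Example: batch 2, seqlen 16, 12 heads, head_dim 64, fp16 -> packsize 8
kv_cache = torch.randn(2, 16, 2, 12, 64, dtype=torch.float16)
k_cache, v_cache = to_ft_cache(kv_cache)
print(k_cache.shape)  # torch.Size([2, 12, 8, 16, 8])  (b, h, d / packsize, s, packsize)
print(v_cache.shape)  # torch.Size([2, 12, 16, 64])    (b, h, s, d)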
@@ -14,10 +14,11 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 @pytest.mark.parametrize('fused_ft_kernel', [False, True])
+# @pytest.mark.parametrize('fused_ft_kernel', [True])
 @pytest.mark.parametrize('optimized', [False, True])
-# @pytest.mark.parametrize('fused_ft_kernel', [False])
-# @pytest.mark.parametrize('optimized', [True])
+# @pytest.mark.parametrize('optimized', [False])
 @pytest.mark.parametrize('rotary', [False, True])
+# @pytest.mark.parametrize('rotary', [False])
 @pytest.mark.parametrize('model_name', ["gpt2"])
 def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     """Check that our implementation of GPT2 generation matches the HF implementation:
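To exercise just this test locally, one option (assuming pytest is installed; selecting by name since the test file path is truncated in this excerpt) is:

# Hypothetical invocation: select the test by name rather than by file path,
# since the path is not shown in the excerpt above.
import pytest

pytest.main(['-q', '-k', 'test_greedy_decode'])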