gaoqiong / flash-attention · Commits

Commit 0938298e
authored Jan 07, 2023 by Tri Dao

[Gen] Adjust shape of kv_cache when using FT

parent e02fd588

Showing 2 changed files with 37 additions and 17 deletions:
    flash_attn/modules/mha.py            +34  -15
    tests/models/test_gpt_generation.py   +3   -2
flash_attn/modules/mha.py

@@ -359,7 +359,7 @@ class MHA(nn.Module):
         self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)

     def _update_kv_cache(self, kv, inference_params):
-        """kv: (batch_size, 1, nheads, head_dim)
+        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
         """
         assert not self.dwconv, 'Generation does not support dwconv yet'
         assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
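To make the updated docstring concrete, here is a small illustration (not part of the commit) of the two kv shapes `_update_kv_cache` now accepts: a full prompt of length `seqlen` during prefill, and a single-token update during decoding, with dimension 2 holding the key and the value.

# Illustrative only: the two kv shapes the updated docstring describes.
import torch

batch_size, nheads, head_dim = 2, 12, 64

kv_prompt = torch.randn(batch_size, 5, 2, nheads, head_dim)  # prefill: seqlen = 5
kv_step = torch.randn(batch_size, 1, 2, nheads, head_dim)    # one decoding step: seqlen = 1

k, v = kv_prompt.unbind(dim=2)  # keys are kv[:, :, 0], values are kv[:, :, 1]
assert k.shape == (batch_size, 5, nheads, head_dim)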
@@ -371,26 +371,45 @@ class MHA(nn.Module):
             inference_params.key_value_memory_dict[self.layer_idx] = kv_cache
         else:
-            assert not inference_params.fused_ft_kernel, 'fused_ft_kernel should not take this path'
-            kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
+            if not inference_params.fused_ft_kernel:
+                kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
+            else:
+                # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
+                # where packsize = 4 if fp32, 8 if fp16 or bf16.
+                # v_cache has shape (b, h, s, headdim)
+                k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
+                kv_cache = None
         # Adjust key and value for inference
         batch_start = inference_params.batch_size_offset
         batch_end = batch_start + kv.shape[0]
-        assert batch_end <= kv_cache.shape[0]
         sequence_start = inference_params.sequence_len_offset
         sequence_end = sequence_start + kv.shape[1]
-        assert sequence_end <= kv_cache.shape[1]
+        assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
+        assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
         # Copy key and values.
-        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-        kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
-        if inference_params.fused_ft_kernel:
+        if not inference_params.fused_ft_kernel:
+            assert kv_cache is not None
+            kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+            kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
+            return kv
+        else:
+            assert inference_params.sequence_len_offset == 0
             # FT kernel requires different layouts for the k_cache and v_cache.
-            assert kv_cache.dtype in [torch.float16, torch.bfloat16, torch.float32]
-            packsize = 4 if kv_cache.dtype == torch.float32 else 8
-            k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
-                                packsize=packsize).contiguous()
-            v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
-            inference_params.key_value_memory_dict[self.layer_idx] = (k_cache, v_cache)
-        return kv
+            assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
+            packsize = 4 if kv.dtype == torch.float32 else 8
+            if kv_cache is not None:
+                kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+                k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
+                                    packsize=packsize).contiguous()
+                v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
+                inference_params.key_value_memory_dict[self.layer_idx] = (k_cache, v_cache)
+            else:
+                k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
+                    kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize)
+                v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
+                    kv[:, :, 1], 'b s h d -> b h s d')
+            return kv

     def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
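The FasterTransformer-style cache layout referenced above can be illustrated standalone. The sketch below is illustrative only (the shapes and the packsize rule come from the comments in the diff; the tensor sizes are arbitrary): it converts a standard (batch, seqlen, 2, nheads, head_dim) kv_cache into the FT k_cache/v_cache layouts with einops.rearrange and checks that the packed keys round-trip back to the original.

# Illustrative sketch (not part of the commit): standard kv_cache -> FT layouts.
import torch
from einops import rearrange

b, s, h, d = 2, 16, 12, 64          # batch, max seqlen, nheads, head_dim
kv_cache = torch.randn(b, s, 2, h, d, dtype=torch.float16)

packsize = 4 if kv_cache.dtype == torch.float32 else 8  # 4 for fp32, 8 for fp16/bf16
# k_cache: (b, h, head_dim // packsize, s, packsize)
k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
                    packsize=packsize).contiguous()
# v_cache: (b, h, s, head_dim)
v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()

assert k_cache.shape == (b, h, d // packsize, s, packsize)
assert v_cache.shape == (b, h, s, d)
# Round-trip: unpacking k_cache recovers the original keys.
k_roundtrip = rearrange(k_cache, 'b h d s packsize -> b s h (d packsize)')
assert torch.equal(k_roundtrip, kv_cache[:, :, 0])

This also explains why the diff moves the dtype and packsize checks from kv_cache to kv: once the cache for a layer has been converted to the (k_cache, v_cache) tuple, kv_cache is None, and the incoming kv tensor is the only tensor left to inspect.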
tests/models/test_gpt_generation.py

@@ -14,10 +14,11 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained

 @pytest.mark.parametrize('fused_ft_kernel', [False, True])
+# @pytest.mark.parametrize('fused_ft_kernel', [True])
 @pytest.mark.parametrize('optimized', [False, True])
-# @pytest.mark.parametrize('fused_ft_kernel', [False])
+# @pytest.mark.parametrize('optimized', [False])
+# @pytest.mark.parametrize('optimized', [True])
 @pytest.mark.parametrize('rotary', [False, True])
-# @pytest.mark.parametrize('rotary', [False])
 @pytest.mark.parametrize('model_name', ["gpt2"])
 def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     """Check that our implementation of GPT2 generation matches the HF implementation:
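For context, the comparison this test performs can be sketched with plain Hugging Face components. The snippet below is illustrative only and does not use the flash_attn model or the fused FT kernel at all: it shows what "greedy decode matches HF" means by checking that a manual argmax loop reproduces the output of transformers' greedy generate; the real test runs the same kind of comparison between the flash_attn GPT2 implementation and the HF one.

# Illustrative sketch (not from the repo): greedy decoding vs. HF generate.
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

input_ids = tokenizer("Hello, my dog is cute and", return_tensors="pt").input_ids
max_length = input_ids.shape[1] + 8

# Reference: HF greedy generation (do_sample=False, no beam search).
out_ref = model.generate(input_ids, max_length=max_length, do_sample=False,
                         pad_token_id=tokenizer.eos_token_id)

# Manual greedy loop: take the argmax of the next-token logits at every step.
cur = input_ids
with torch.no_grad():
    for _ in range(max_length - input_ids.shape[1]):
        logits = model(cur).logits[:, -1, :]
        next_token = logits.argmax(dim=-1, keepdim=True)
        cur = torch.cat([cur, next_token], dim=1)

# Greedy decoding is deterministic, so the generated prefixes agree.
assert torch.equal(cur[:, :out_ref.shape[1]], out_ref)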