gaoqiong / flash-attention

Commit a668890f, authored Jan 03, 2023 by Tri Dao

[Gen] Add option to run generation with FT attention kernel

Parent: be1afaa2
Showing 3 changed files with 54 additions and 21 deletions (+54, -21):

    flash_attn/modules/mha.py             +39  -15
    flash_attn/utils/generation.py        +11   -4
    tests/models/test_gpt_generation.py    +4   -2
flash_attn/modules/mha.py

@@ -30,6 +30,11 @@ try:
 except ImportError:
     RotaryEmbedding = None
 
+try:
+    import ft_attention
+except ImportError:
+    ft_attention = None
+
 
 class FlashSelfAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
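The guarded import above makes ft_attention (the fused FasterTransformer decoding kernel) an optional dependency: if the extension is not installed, ft_attention is simply None and only the fused generation path is unavailable. A hedged sketch, not part of this commit, of how a caller might gate on that availability (the helper name is hypothetical):

    # Hypothetical helper, for illustration only: request the fused FT path only when the
    # optional ft_attention extension imports successfully (mirrors the
    # `assert ft_attention is not None` check in MHA.forward below).
    def can_use_fused_ft_kernel():
        try:
            import ft_attention  # noqa: F401
        except ImportError:
            return False
        return True

    fused_ft_kernel = can_use_fused_ft_kernel()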
@@ -360,23 +365,32 @@ class MHA(nn.Module):
         assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
         # Pre-allocate memory for key-values for inference.
         if self.layer_idx not in inference_params.key_value_memory_dict:
-            inference_kv_cache = torch.empty(
+            kv_cache = torch.empty(
                 inference_params.max_batch_size, inference_params.max_sequence_len, 2,
                 self.num_heads, self.head_dim, dtype=kv.dtype, device=kv.device
             )
-            inference_params.key_value_memory_dict[self.layer_idx] = inference_kv_cache
+            inference_params.key_value_memory_dict[self.layer_idx] = kv_cache
         else:
-            inference_kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
+            assert not inference_params.fused_ft_kernel, 'fused_ft_kernel should not take this path'
+            kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
         # Adjust key and value for inference
         batch_start = inference_params.batch_size_offset
         batch_end = batch_start + kv.shape[0]
-        assert batch_end <= inference_kv_cache.shape[0]
+        assert batch_end <= kv_cache.shape[0]
         sequence_start = inference_params.sequence_len_offset
         sequence_end = sequence_start + kv.shape[1]
-        assert sequence_end <= inference_kv_cache.shape[1]
+        assert sequence_end <= kv_cache.shape[1]
         # Copy key and values.
-        inference_kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-        kv = inference_kv_cache[batch_start:batch_end, :sequence_end, ...]
+        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+        kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
+        if inference_params.fused_ft_kernel:
+            # FT kernel requires different layouts for the k_cache and v_cache.
+            assert kv_cache.dtype in [torch.float16, torch.bfloat16, torch.float32]
+            packsize = 4 if kv_cache.dtype == torch.float32 else 8
+            k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
+                                packsize=packsize).contiguous()
+            v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
+            inference_params.key_value_memory_dict[self.layer_idx] = (k_cache, v_cache)
         return kv
 
     def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
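The comment "FT kernel requires different layouts" is the key point of this hunk: the shared (batch, seqlen, 2, nheads, headdim) cache is re-packed so keys become (batch, nheads, headdim/packsize, seqlen, packsize) and values become (batch, nheads, seqlen, headdim). A standalone sketch of the same two rearrange calls on a dummy cache (sizes and dtype are illustrative, not from the commit):

    # Standalone illustration of the k_cache/v_cache re-packing above (dummy sizes).
    import torch
    from einops import rearrange

    batch, seqlen, nheads, headdim = 2, 16, 12, 64
    kv_cache = torch.zeros(batch, seqlen, 2, nheads, headdim, dtype=torch.float16)

    packsize = 4 if kv_cache.dtype == torch.float32 else 8  # fp16/bf16 pack 8 elements
    k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
                        packsize=packsize).contiguous()
    v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()

    assert k_cache.shape == (batch, nheads, headdim // packsize, seqlen, packsize)
    assert v_cache.shape == (batch, nheads, seqlen, headdim)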
@@ -430,14 +444,24 @@ class MHA(nn.Module):
                 else:
                     context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
             else:
-                if self.rotary_emb_dim > 0:
-                    qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
-                q = qkv[:, :, 0]
-                kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
-                # If we're processing the prompt, causal=None (use self.causal).
-                # If we're decoding, then causal=False.
-                causal = None if inference_params.sequence_len_offset == 0 else False
-                context = self.inner_cross_attn(q, kv, causal=causal)
+                if (not inference_params.fused_ft_kernel) or inference_params.sequence_len_offset == 0:
+                    if self.rotary_emb_dim > 0:
+                        qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
+                    q = qkv[:, :, 0]
+                    kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
+                    # If we're processing the prompt, causal=None (use self.causal).
+                    # If we're decoding, then causal=False.
+                    causal = None if inference_params.sequence_len_offset == 0 else False
+                    context = self.inner_cross_attn(q, kv, causal=causal)
+                else:
+                    assert ft_attention is not None
+                    context = ft_attention.single_query_attention(
+                        *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
+                        *inference_params.key_value_memory_dict[self.layer_idx],
+                        inference_params.lengths_per_sample, inference_params.sequence_len_offset,
+                        self.rotary_emb_dim
+                    )
+                    context = rearrange(context, 'b h d -> b 1 h d')
         else:
             if not self.return_residual:
                 q = self.Wq(x)
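In the new branch, the fused path is taken only when fused_ft_kernel is set and sequence_len_offset is nonzero, i.e. when decoding a single new token per sequence; the prompt is still processed by the regular attention path. At that point qkv has shape (batch, 1, 3, nheads, headdim) and is split into separate tensors before the kernel call. A hedged sketch of just that reshaping step (the kernel call itself is not reproduced here):

    # Illustration of the qkv unpacking used on the fused decode path (dummy sizes).
    import torch
    from einops import rearrange

    batch, nheads, headdim = 2, 12, 64
    qkv = torch.randn(batch, 1, 3, nheads, headdim)  # one new token per sequence

    # 'b 1 three h d -> b three h d' drops the singleton time step; unbind(dim=1) then
    # yields q, k, v, each of shape (batch, nheads, headdim), which the diff passes
    # positionally to ft_attention.single_query_attention.
    q, k, v = rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1)
    assert q.shape == k.shape == v.shape == (batch, nheads, headdim)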
flash_attn/utils/generation.py

 # Copyright (c) 2022, Tri Dao.
 # Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
+from typing import Optional
+
 from dataclasses import dataclass, field
 
 import torch
+from torch import Tensor
 
 from einops import rearrange
@@ -17,9 +20,11 @@ class InferenceParams:
     sequence_len_offset: int = 0
     batch_size_offset: int = 0
     key_value_memory_dict: dict = field(default_factory=dict)
+    fused_ft_kernel: bool = False
+    lengths_per_sample: Optional[Tensor] = None
 
 
-def greedy_decode(input_ids, model, max_length):
+def greedy_decode(input_ids, model, max_length, fused_ft_kernel=True):
     """Greedy decoding. This is a very simple implementation.
     We assume that all sequences in the same batch have the same length.
     Arguments:
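Both new InferenceParams fields default to the non-fused configuration, so existing callers are unchanged; greedy_decode now forwards its fused_ft_kernel flag into the constructor (next hunk). A small hedged usage sketch (the sizes are illustrative only):

    # Illustrative only: the defaults keep the eager PyTorch attention path; opting into
    # the FT kernel is a single constructor argument.
    from flash_attn.utils.generation import InferenceParams

    params = InferenceParams(max_sequence_len=256, max_batch_size=4)       # eager path
    params_ft = InferenceParams(max_sequence_len=256, max_batch_size=4,
                                fused_ft_kernel=True)                       # FT kernel path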
@@ -30,7 +35,8 @@ def greedy_decode(input_ids, model, max_length):
         scores: tuples of (batch, vocab_size)
     """
     batch_size, seqlen_og = input_ids.shape
-    inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+    inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
+                                       fused_ft_kernel=fused_ft_kernel)
     scores = []
     with torch.inference_mode():
         logits = model(input_ids, inference_params=inference_params).logits[:, -1]
@@ -57,8 +63,9 @@ def greedy_decode(input_ids, model, max_length):
 
 class GenerationMixin:
 
-    def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False):
-        output = greedy_decode(input_ids, self, max_length)
+    def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False,
+                 **kwargs):
+        output = greedy_decode(input_ids, self, max_length, **kwargs)
         if not output_scores:
             output.scores = None
         return output if return_dict_in_generate else output.sequences
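Since generate now forwards **kwargs to greedy_decode, the FT kernel can be toggled from the public API. A hedged end-to-end sketch, assuming `model` is a flash_attn GPT-style model that mixes in GenerationMixin on a CUDA device (model construction is not part of this commit):

    # Illustrative call pattern only; `model` is assumed to exist and to inherit
    # GenerationMixin, with fp16/bf16 weights on GPU.
    import torch

    input_ids = torch.randint(0, 50257, (2, 16), dtype=torch.long, device='cuda')
    out = model.generate(input_ids=input_ids, max_length=64,
                         fused_ft_kernel=True,  # forwarded to greedy_decode
                         return_dict_in_generate=True, output_scores=True)
    sequences, scores = out.sequences, out.scores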
tests/models/test_gpt_generation.py

@@ -15,10 +15,11 @@ from flash_attn.utils.generation import greedy_decode
 
 # TODO: test with rotary embedding
+@pytest.mark.parametrize('fused_ft_kernel', [False, True])
 @pytest.mark.parametrize('optimized', [False, True])
 # @pytest.mark.parametrize('optimized', [False])
 # @pytest.mark.parametrize('optimized', [True])
 @pytest.mark.parametrize('model_name', ["gpt2"])
-def test_greedy_decode(model_name, optimized):
+def test_greedy_decode(model_name, optimized, fused_ft_kernel):
     """Check that our implementation of GPT2 generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
@@ -62,6 +63,7 @@ def test_greedy_decode(model_name, optimized):
     scores = tuple(scores)
 
     out = model.generate(input_ids=input_ids, max_length=max_length,
+                         fused_ft_kernel=fused_ft_kernel,
                          return_dict_in_generate=True, output_scores=True)
     out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
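The new fused_ft_kernel parametrization doubles the test matrix: every model_name x optimized combination now runs both with and without the FT kernel. A hedged convenience snippet for running just this test, assuming a CUDA build with the optional ft_attention extension installed:

    # Illustrative only: invoke the parametrized test from Python (equivalent to running
    # pytest on the file from the command line).
    import pytest

    pytest.main(["tests/models/test_gpt_generation.py", "-k", "test_greedy_decode", "-q"])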