gaoqiong / flash-attention

Commit a668890f, authored Jan 03, 2023 by Tri Dao (parent be1afaa2)

[Gen] Add option to run generation with FT attention kernel
Showing 3 changed files with 54 additions and 21 deletions (+54 -21):

flash_attn/modules/mha.py             +39 -15
flash_attn/utils/generation.py        +11 -4
tests/models/test_gpt_generation.py   +4  -2
flash_attn/modules/mha.py (view file @ a668890f)

@@ -30,6 +30,11 @@ try:
 except ImportError:
     RotaryEmbedding = None
 
+try:
+    import ft_attention
+except ImportError:
+    ft_attention = None
+
 
 class FlashSelfAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
@@ -360,23 +365,32 @@ class MHA(nn.Module):
         assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
         # Pre-allocate memory for key-values for inference.
         if self.layer_idx not in inference_params.key_value_memory_dict:
-            inference_kv_cache = torch.empty(
+            kv_cache = torch.empty(
                 inference_params.max_batch_size, inference_params.max_sequence_len, 2,
                 self.num_heads, self.head_dim, dtype=kv.dtype, device=kv.device
             )
-            inference_params.key_value_memory_dict[self.layer_idx] = inference_kv_cache
+            inference_params.key_value_memory_dict[self.layer_idx] = kv_cache
         else:
-            inference_kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
+            assert not inference_params.fused_ft_kernel, 'fused_ft_kernel should not take this path'
+            kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
         # Adjust key and value for inference
         batch_start = inference_params.batch_size_offset
         batch_end = batch_start + kv.shape[0]
-        assert batch_end <= inference_kv_cache.shape[0]
+        assert batch_end <= kv_cache.shape[0]
         sequence_start = inference_params.sequence_len_offset
         sequence_end = sequence_start + kv.shape[1]
-        assert sequence_end <= inference_kv_cache.shape[1]
+        assert sequence_end <= kv_cache.shape[1]
         # Copy key and values.
-        inference_kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-        kv = inference_kv_cache[batch_start:batch_end, :sequence_end, ...]
+        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+        kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
+        if inference_params.fused_ft_kernel:
+            # FT kernel requires different layouts for the k_cache and v_cache.
+            assert kv_cache.dtype in [torch.float16, torch.bfloat16, torch.float32]
+            packsize = 4 if kv_cache.dtype == torch.float32 else 8
+            k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
+                                packsize=packsize).contiguous()
+            v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
+            inference_params.key_value_memory_dict[self.layer_idx] = (k_cache, v_cache)
         return kv
 
     def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
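To make the new cache layout concrete, here is a small self-contained sketch (made-up tensor sizes, not code from this commit) of the two rearranges above: the FT kernel expects the key cache packed along the head dimension in groups of `packsize` elements (8 for fp16/bf16, 4 for fp32) and the value cache in `(batch, heads, seqlen, headdim)` order.

```python
# Illustrative shapes only; the sizes below are arbitrary.
import torch
from einops import rearrange

batch, seqlen, nheads, headdim = 2, 16, 12, 64
kv_cache = torch.empty(batch, seqlen, 2, nheads, headdim, dtype=torch.float16)

packsize = 4 if kv_cache.dtype == torch.float32 else 8   # elements per vectorized load
k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
                    packsize=packsize).contiguous()
v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()

print(k_cache.shape)  # torch.Size([2, 12, 8, 16, 8]) -> (b, h, headdim // packsize, s, packsize)
print(v_cache.shape)  # torch.Size([2, 12, 16, 64])   -> (b, h, s, headdim)
```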
@@ -430,14 +444,24 @@ class MHA(nn.Module):
                 else:
                     context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
             else:
-                if self.rotary_emb_dim > 0:
-                    qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
-                q = qkv[:, :, 0]
-                kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
-                # If we're processing the prompt, causal=None (use self.causal).
-                # If we're decoding, then causal=False.
-                causal = None if inference_params.sequence_len_offset == 0 else False
-                context = self.inner_cross_attn(q, kv, causal=causal)
+                if (not inference_params.fused_ft_kernel) or inference_params.sequence_len_offset == 0:
+                    if self.rotary_emb_dim > 0:
+                        qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
+                    q = qkv[:, :, 0]
+                    kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
+                    # If we're processing the prompt, causal=None (use self.causal).
+                    # If we're decoding, then causal=False.
+                    causal = None if inference_params.sequence_len_offset == 0 else False
+                    context = self.inner_cross_attn(q, kv, causal=causal)
+                else:
+                    assert ft_attention is not None
+                    context = ft_attention.single_query_attention(
+                        *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
+                        *inference_params.key_value_memory_dict[self.layer_idx],
+                        inference_params.lengths_per_sample, inference_params.sequence_len_offset,
+                        self.rotary_emb_dim
+                    )
+                    context = rearrange(context, 'b h d -> b 1 h d')
         else:
             if not self.return_residual:
                 q = self.Wq(x)
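At single-token decode time the sequence dimension is 1, so the fused branch strips the singleton before handing q, k, v and the pre-converted caches to the kernel, then restores it afterwards. A shape-only sketch of that rearrange/unbind round trip (illustrative sizes; the `ft_attention` kernel itself is not invoked here since it may not be built):

```python
# Shape sketch for the fused decode call above; sizes are arbitrary.
import torch
from einops import rearrange

batch, nheads, headdim = 2, 12, 64
qkv = torch.randn(batch, 1, 3, nheads, headdim)            # one new token per sequence
q, k, v = rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1)
print(q.shape, k.shape, v.shape)                            # each (batch, nheads, headdim)

context = torch.zeros(batch, nheads, headdim)               # stand-in for the kernel's output
print(rearrange(context, 'b h d -> b 1 h d').shape)         # back to (batch, 1, nheads, headdim)
```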
flash_attn/utils/generation.py (view file @ a668890f)

 # Copyright (c) 2022, Tri Dao.
 # Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
+from typing import Optional
+
 from dataclasses import dataclass, field
 
 import torch
+from torch import Tensor
 
 from einops import rearrange

@@ -17,9 +20,11 @@ class InferenceParams:
     sequence_len_offset: int = 0
     batch_size_offset: int = 0
     key_value_memory_dict: dict = field(default_factory=dict)
+    fused_ft_kernel: bool = False
+    lengths_per_sample: Optional[Tensor] = None
 
 
-def greedy_decode(input_ids, model, max_length):
+def greedy_decode(input_ids, model, max_length, fused_ft_kernel=True):
     """Greedy decoding. This is a very simple implementation.
     We assume that all sequences in the same batch have the same length.
     Arguments:

@@ -30,7 +35,8 @@ def greedy_decode(input_ids, model, max_length):
         scores: tuples of (batch, vocab_size)
     """
     batch_size, seqlen_og = input_ids.shape
-    inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size)
+    inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
+                                       fused_ft_kernel=fused_ft_kernel)
     scores = []
     with torch.inference_mode():
         logits = model(input_ids, inference_params=inference_params).logits[:, -1]

@@ -57,8 +63,9 @@ def greedy_decode(input_ids, model, max_length):
 
 class GenerationMixin:
 
-    def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False):
-        output = greedy_decode(input_ids, self, max_length)
+    def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False,
+                 **kwargs):
+        output = greedy_decode(input_ids, self, max_length, **kwargs)
         if not output_scores:
             output.scores = None
         return output if return_dict_in_generate else output.sequences
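As a usage illustration (not code from this commit), the new flag threads from `GenerationMixin.generate(..., fused_ft_kernel=...)` through `greedy_decode` into `InferenceParams`, where the MHA layers pick it up. A hedged sketch: the `GPTLMHeadModel` constructor arguments below are assumptions about the flash_attn GPT API of the time, and `fused_ft_kernel=False` keeps the pure-PyTorch kv-cache path so the compiled `ft_attention` extension is not required.

```python
# Hedged sketch, not from this commit: exercise greedy_decode with the new flag on a
# small randomly initialized GPT2-style model (untrained, so the generated tokens are meaningless).
import torch
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.utils.generation import greedy_decode

config = GPT2Config(n_embd=256, n_head=8, n_layer=2)
model = GPTLMHeadModel(config, device='cuda', dtype=torch.float16)  # assumed constructor kwargs
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8), dtype=torch.long, device='cuda')
# fused_ft_kernel=True would route single-token decode steps through ft_attention (if built);
# False keeps the Python kv-cache path shown in mha.py above.
out = greedy_decode(input_ids, model, max_length=32, fused_ft_kernel=False)
print(out.sequences.shape)  # expected: (1, 32)
```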
tests/models/test_gpt_generation.py (view file @ a668890f)

@@ -15,10 +15,11 @@ from flash_attn.utils.generation import greedy_decode
 
 
 # TODO: test with rotary embedding
+@pytest.mark.parametrize('fused_ft_kernel', [False, True])
 @pytest.mark.parametrize('optimized', [False, True])
-# @pytest.mark.parametrize('optimized', [False])
+# @pytest.mark.parametrize('optimized', [True])
 @pytest.mark.parametrize('model_name', ["gpt2"])
-def test_greedy_decode(model_name, optimized):
+def test_greedy_decode(model_name, optimized, fused_ft_kernel):
     """Check that our implementation of GPT2 generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.

@@ -62,6 +63,7 @@ def test_greedy_decode(model_name, optimized):
     scores = tuple(scores)
 
     out = model.generate(input_ids=input_ids, max_length=max_length,
+                         fused_ft_kernel=fused_ft_kernel,
                          return_dict_in_generate=True, output_scores=True)
     out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
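With the stacked parametrize decorators, pytest runs `test_greedy_decode` over the cartesian product of `model_name`, `optimized`, and `fused_ft_kernel`, so each combination appears as its own test id. A typical invocation (assuming a CUDA machine with the flash_attn extensions installed) would be `pytest tests/models/test_gpt_generation.py::test_greedy_decode`; the `fused_ft_kernel=True` cases additionally require the compiled `ft_attention` extension.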