gaoqiong / flash-attention

Unverified commit 011ec323, authored Aug 30, 2023 by dan_the_3rd, committed by GitHub on Aug 30, 2023. Parent: 0cb595ad.

Support MQA + MP for decoding (#490)

Co-authored-by: danthe3rd <danthe3rd>
Showing 3 changed files with 70 additions and 39 deletions:

  flash_attn/models/gpt.py    +41  -36
  flash_attn/modules/mha.py    +2   -2
  tests/models/test_gpt.py    +27   -1
flash_attn/models/gpt.py @ 011ec323
@@ -623,22 +623,24 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
     def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
         """
         input_ids: (batch, seqlen) int tensor
         inference_params: for generation. Adapted from Megatron-LM (and Apex)
         https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
         num_last_tokens: if > 0, only return the logits for the last n tokens
         """
+        assert (
+            input_ids.ndim == 2
+        ), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
+        b, slen = input_ids.shape
         hidden_states = self.transformer(
             input_ids, position_ids=position_ids, inference_params=inference_params
         )
         if num_last_tokens > 0:
             hidden_states = hidden_states[:, -num_last_tokens:]
+        assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
         lm_logits = self.lm_head(hidden_states)
         # During inference, we want the full logit for sampling
         if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
             lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
-            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=hidden_states.shape[0])
+            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=b)
         CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
         return CausalLMOutput(logits=lm_logits)
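The b=b change ties the logits re-assembly to the batch size taken from input_ids. With a tensor-parallel lm_head (ColumnParallelLinear), each rank computes logits only for its slice of the vocabulary, and all_gather_raw concatenates the per-rank results along dim 0; the rearrange then folds that rank axis back into the vocabulary axis. A minimal sketch of that reshaping, with made-up toy sizes and the all-gather simulated by torch.cat:

import torch
from einops import rearrange

# Toy sizes (made up): 2 tensor-parallel ranks, batch 3, seqlen 5, 8 vocab entries per rank.
world_size, b, seqlen, vocab_per_rank = 2, 3, 5, 8
# Each rank's lm_head produces logits only for its slice of the vocabulary.
per_rank_logits = [torch.randn(b, seqlen, vocab_per_rank) for _ in range(world_size)]
# The all-gather stacks the per-rank tensors along dim 0, giving
# (world_size * b, seqlen, vocab_per_rank); simulated here with torch.cat.
gathered = torch.cat(per_rank_logits, dim=0)
# Same pattern as in GPTLMHeadModel.forward: fold the rank axis back into the vocab axis.
full_logits = rearrange(gathered, "(n b) ... d -> b ... (n d)", b=b)
assert full_logits.shape == (b, seqlen, world_size * vocab_per_rank)
# Taking b from input_ids.shape (as the updated forward does) makes the batch size explicit
# instead of reading it off hidden_states.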
@@ -802,6 +804,8 @@ def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: G
     assert config.hidden_size % world_size == 0
     inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
     assert inner_dim % world_size == 0
+    assert config.hidden_size % config.n_head == 0
+    headdim = config.hidden_size // config.n_head
     # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
     # vocab_size // world_size coordinates are nonzero.
@@ -823,14 +827,6 @@ def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: G
                 ]
                 state_dict[key] = rearrange(torch.cat(xs, dim=1), "three d ... -> (three d) ...")
             else:
-                xs = [
-                    rearrange(
-                        s[key],
-                        "(nheadqkv headdim) ... -> nheadqkv headdim ...",
-                        nheadqkv=n_head + 2 * n_head_kv,
-                    )
-                    for s in state_dicts
-                ]
                 n_head_each_rank = [
                     get_dim_for_local_rank(n_head, world_size, local_rank)
                     for local_rank in range(world_size)
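Both head-count lists here come from the get_dim_for_local_rank helper used throughout this code path: conceptually it answers how many query heads and how many KV heads a given tensor-parallel rank owns. The following is an illustrative stand-in for that splitting rule, with made-up head counts (a sketch of the intended behaviour, not the library's exact implementation):

def dim_for_local_rank(dim: int, world_size: int, local_rank: int) -> int:
    """Split `dim` across `world_size` ranks as evenly as possible; earlier ranks
    absorb any remainder. Hypothetical stand-in for get_dim_for_local_rank."""
    base, rem = divmod(dim, world_size)
    return base + (1 if local_rank < rem else 0)

# MQA-style example: 8 query heads and 2 KV heads split over 2 ranks.
n_head, n_head_kv, world_size = 8, 2, 2
n_head_each_rank = [dim_for_local_rank(n_head, world_size, r) for r in range(world_size)]
n_head_kv_each_rank = [dim_for_local_rank(n_head_kv, world_size, r) for r in range(world_size)]
assert n_head_each_rank == [4, 4] and n_head_kv_each_rank == [1, 1]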
@@ -839,13 +835,19 @@ def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: G
@@ -854,17 +856,20 @@ def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: G
                     get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                     for local_rank in range(world_size)
                 ]
-                state_dict[key] = rearrange(
-                    torch.cat(
-                        [
-                            torch.cat(
-                                [x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0
-                            ),
-                            torch.cat(
-                                [
-                                    x[n_head_each_rank[rank] : n_head_each_rank[rank] + n_head_kv_each_rank[rank]]
-                                    for rank, x in enumerate(xs)
-                                ],
-                                dim=0,
-                            ),
-                            torch.cat(
-                                [
-                                    x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
-                                    for rank, x in enumerate(xs)
-                                ],
-                                dim=0,
-                            ),
-                        ],
-                    ),
-                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
-                )
+                xs = [
+                    rearrange(
+                        s[key],
+                        "(nheadqkv headdim) ... -> nheadqkv headdim ...",
+                        nheadqkv=rank_n_head + 2 * rank_n_head_kv,
+                        headdim=headdim,
+                    )
+                    for s, rank_n_head, rank_n_head_kv in zip(
+                        state_dicts, n_head_each_rank, n_head_kv_each_rank
+                    )
+                ]
+                wq = torch.cat([x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0)
+                wk = torch.cat(
+                    [
+                        x[n_head_each_rank[rank] : n_head_each_rank[rank] + n_head_kv_each_rank[rank]]
+                        for rank, x in enumerate(xs)
+                    ],
+                    dim=0,
+                )
+                wv = torch.cat(
+                    [
+                        x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
+                        for rank, x in enumerate(xs)
+                    ],
+                    dim=0,
+                )
+                wqkv = torch.cat([wq, wk, wv], dim=0)
+                state_dict[key] = rearrange(
+                    wqkv,
+                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
+                )
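Read together, the new branch rebuilds the full packed QKV weight by slicing each rank's shard into its own q, k and v heads (using the per-rank head counts) and then concatenating all q heads, all k heads, and all v heads before flattening. The toy round trip below mirrors that logic outside the library, with made-up head counts and sizes, and checks that combining the shards reproduces the original packed weight:

import torch
from einops import rearrange

n_head, n_head_kv, headdim, hidden, world_size = 8, 2, 4, 32, 2  # toy GQA config (made up)
n_head_each_rank = [4, 4]      # n_head split over 2 ranks
n_head_kv_each_rank = [1, 1]   # n_head_kv split over 2 ranks

# Full (unsharded) packed weight, laid out as [all q heads | all k heads | all v heads].
wqkv_full = torch.randn((n_head + 2 * n_head_kv) * headdim, hidden)
q, k, v = torch.split(
    wqkv_full, [n_head * headdim, n_head_kv * headdim, n_head_kv * headdim], dim=0
)

# What each tensor-parallel rank would hold: its slice of q, k and v, re-packed together.
shards = []
for rank in range(world_size):
    qh, kvh = n_head_each_rank[rank], n_head_kv_each_rank[rank]
    q_start = sum(n_head_each_rank[:rank]) * headdim
    kv_start = sum(n_head_kv_each_rank[:rank]) * headdim
    shards.append(
        torch.cat(
            [
                q[q_start : q_start + qh * headdim],
                k[kv_start : kv_start + kvh * headdim],
                v[kv_start : kv_start + kvh * headdim],
            ],
            dim=0,
        )
    )

# Recombine the shards the way the updated combine_state_dicts_tp does.
xs = [
    rearrange(
        s,
        "(nheadqkv headdim) ... -> nheadqkv headdim ...",
        nheadqkv=qh + 2 * kvh,
        headdim=headdim,
    )
    for s, qh, kvh in zip(shards, n_head_each_rank, n_head_kv_each_rank)
]
wq = torch.cat([x[: n_head_each_rank[r]] for r, x in enumerate(xs)], dim=0)
wk = torch.cat(
    [x[n_head_each_rank[r] : n_head_each_rank[r] + n_head_kv_each_rank[r]] for r, x in enumerate(xs)],
    dim=0,
)
wv = torch.cat([x[n_head_each_rank[r] + n_head_kv_each_rank[r] :] for r, x in enumerate(xs)], dim=0)
combined = rearrange(torch.cat([wq, wk, wv], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...")
assert torch.equal(combined, wqkv_full)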
flash_attn/modules/mha.py @ 011ec323
@@ -735,7 +735,7 @@ class ParallelMHA(nn.Module):
             self.num_heads, self.world_size, self.local_rank
         )
         self.num_heads_kv_per_rank = get_dim_for_local_rank(
-            self.num_heads, self.world_size, self.local_rank
+            self.num_heads_kv, self.world_size, self.local_rank
         )
         self.head_dim = self.embed_dim // num_heads
         qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
@@ -758,7 +758,7 @@ class ParallelMHA(nn.Module):
             process_group,
             bias=qkv_proj_bias,
             sequence_parallel=sequence_parallel,
-            multiple_of=self.head_dim * 3,
+            multiple_of=self.head_dim * (self.num_heads_per_rank + 2 * self.num_heads_kv_per_rank),
             **factory_kwargs,
         )
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
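The multiple_of change reflects that with separate KV heads, one rank's slice of the packed QKV projection spans its own q heads plus its own k and v heads, so the split granularity follows the per-rank head counts rather than a flat factor of 3. A small arithmetic sketch with made-up sizes (the even per-rank split here is an assumption about what get_dim_for_local_rank returns for these values):

# Toy ParallelMHA-style sizes: 8 query heads, 2 KV heads, tensor-parallel world size 2 (made up).
embed_dim, num_heads, num_heads_kv, world_size = 256, 8, 2, 2
head_dim = embed_dim // num_heads  # 32

# Per-rank head counts, assuming an even split across ranks.
num_heads_per_rank = num_heads // world_size        # 4
num_heads_kv_per_rank = num_heads_kv // world_size  # 1

# Total packed QKV output width, as computed in ParallelMHA.__init__.
qkv_dim = head_dim * (num_heads + 2 * num_heads_kv)  # 32 * 12 = 384

# The new multiple_of: the width of one rank's slice of the packed QKV projection.
per_rank_width = head_dim * (num_heads_per_rank + 2 * num_heads_kv_per_rank)  # 32 * 6 = 192

# The slices of all ranks together cover the full projection exactly.
assert qkv_dim == world_size * per_rank_width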
tests/models/test_gpt.py @ 011ec323
@@ -3,7 +3,7 @@ import re
 import pytest
 import torch
 from einops import rearrange
-from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
+from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2, shard_state_dict_tp, combine_state_dicts_tp
 from flash_attn.utils.generation import InferenceParams
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from transformers import GPT2Config, GPT2Tokenizer
@@ -444,3 +444,29 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
         return_dict_in_generate=True,
     )
     print(tokenizer.batch_decode(out_og.sequences))
+
+
+@pytest.mark.parametrize("n_heads_q_kv", [
+    (8, 8),  # Regular attention
+    (8, 4),  # GQA
+    (8, 2),  # MQA
+])
+def test_gpt2_shard_unshard(n_heads_q_kv):
+    world_size = 2
+
+    config = GPT2Config.from_pretrained("gpt2")
+    config.vocab_size = 1024
+    config.n_head, config.n_head_kv = n_heads_q_kv
+    model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
+    state_dict = model.state_dict()
+    shards = [
+        # NOTE: Shallow copy as `state_dict` is modified in-place
+        shard_state_dict_tp(dict(state_dict), config, world_size, rank)
+        for rank in range(world_size)
+    ]
+    state_dict2 = combine_state_dicts_tp(shards, config)
+    assert state_dict2.keys() == state_dict.keys()
+    for k in state_dict.keys():
+        ref = state_dict[k]
+        new = state_dict2[k]
+        assert torch.allclose(ref, new, atol=0.0, rtol=0.0)
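The new shard/unshard test requires a CUDA device, since the model is instantiated with device="cuda". It can be selected on its own with pytest's -k filter, for example: pytest tests/models/test_gpt.py -k test_gpt2_shard_unshard.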