Commit 425dbcb6, authored Jul 23, 2023 by Tri Dao

[MHA] Implement MQA/GQA

Parent: ec9f74ab
Showing 5 changed files with 340 additions and 153 deletions.
flash_attn/layers/rotary.py          +80   -13
flash_attn/models/gpt.py             +47   -14
flash_attn/modules/mha.py            +209  -121
tests/models/test_gpt_generation.py  +4    -4
tests/models/test_gptj.py            +0    -1
flash_attn/layers/rotary.py
 # Copyright (c) 2023, Tri Dao.

-from typing import Tuple
+from typing import Tuple, Optional
 import math

 import torch
@@ -151,6 +151,51 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):

 apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply


+class ApplyRotaryEmbKV_(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, kv, cos, sin, interleaved=False):
+        """
+            kv: (batch_size, seqlen, 2, nheads, headdim)
+            cos, sin: (seqlen, rotary_dim / 2)
+            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
+                1st half and 2nd half (GPT-NeoX style).
+            rotary_dim must be <= headdim
+            Apply rotary embedding *inplace* to the first rotary_dim of k.
+        """
+        batch, seqlen, two, nheads, headdim = kv.shape
+        assert two == 2
+        rotary_seqlen, rotary_dim = cos.shape
+        rotary_dim *= 2
+        assert rotary_dim <= headdim
+        assert seqlen <= rotary_seqlen
+        k_ro = kv[:, :, 0, :, :rotary_dim]
+        k1, k2 = (k_ro.chunk(2, dim=-1) if not interleaved
+                  else (k_ro[..., ::2], k_ro[..., 1::2]))
+        rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2,
+                                False)  # conj=False since this is the forward pass
+        ctx.save_for_backward(cos, sin)
+        ctx.interleaved = interleaved
+        return kv
+
+    @staticmethod
+    def backward(ctx, dkv):
+        cos, sin = ctx.saved_tensors
+        _, seqlen, _, _, headdim = dkv.shape
+        rotary_dim = cos.shape[-1]
+        rotary_dim *= 2
+        dk_ro = dkv[:, :, 0, :, :rotary_dim]
+        dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
+                    else (dk_ro[..., ::2], dk_ro[..., 1::2]))
+        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2,
+                                True)  # conj=True since this is the backward pass
+        return dkv, None, None, None
+
+
+apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
+
+
 class RotaryEmbedding(torch.nn.Module):
     """
     The rotary position embeddings from RoFormer_ (Su et. al).
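Aside: rotary_emb.apply_rotary in the hunk above is the fused CUDA kernel that rotates the (k1, k2) pair in place. As a plain-PyTorch reference for what the forward pass computes on the first rotary_dim channels of k (a minimal sketch for illustration only; the function name and the out-of-place style are not part of the library):

import torch

def rotate_k_reference(k, cos, sin, interleaved=False):
    # k: (batch, seqlen, nheads, headdim); cos, sin: (seqlen, rotary_dim // 2)
    rotary_dim = 2 * cos.shape[-1]
    seqlen = k.shape[1]
    out = k.clone()
    k_ro = k[..., :rotary_dim]
    if not interleaved:
        k1, k2 = k_ro.chunk(2, dim=-1)              # GPT-NeoX style: 1st half / 2nd half
    else:
        k1, k2 = k_ro[..., ::2], k_ro[..., 1::2]    # GPT-J style: even / odd channels
    c = cos[:seqlen].view(1, seqlen, 1, -1)
    s = sin[:seqlen].view(1, seqlen, 1, -1)
    r1, r2 = k1 * c - k2 * s, k1 * s + k2 * c       # the standard RoPE rotation
    if not interleaved:
        out[..., :rotary_dim // 2] = r1
        out[..., rotary_dim // 2:rotary_dim] = r2
    else:
        out[..., 0:rotary_dim:2] = r1
        out[..., 1:rotary_dim:2] = r2
    return out

The backward pass is the same rotation with the angle negated, which is what the conj=True flag selects in the kernel call.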
@@ -249,21 +294,43 @@ class RotaryEmbedding(torch.nn.Module):
                 self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                 self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

-    def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None,
+                seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        qkv: (batch, seqlen, 3, nheads, headdim)
+        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
+            else it's just q of shape (batch, seqlen, nheads, headdim)
+        kv: (batch, seqlen, 2, nheads, headdim)
         seqlen_offset: can be used in generation where the qkv being passed in is only the last
         token in the batch.
         """
-        self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
-        if self.scale is None:
-            return apply_rotary_emb_qkv_(
-                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
-                None, None, self.interleaved
-            )
-        else:
-            return apply_rotary_emb_qkv_(
-                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
-                self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
-                self.interleaved
-            )
+        seqlen = qkv.shape[1]
+        self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
+        if kv is None:
+            if self.scale is None:
+                return apply_rotary_emb_qkv_(
+                    qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
+                    None, None, self.interleaved
+                )
+            else:
+                return apply_rotary_emb_qkv_(
+                    qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
+                    self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
+                    self.interleaved
+                )
+        else:
+            q = qkv
+            q = apply_rotary_emb_func(
+                q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
+                self.interleaved, True
+            )
+            if self.scale is None:
+                kv = apply_rotary_emb_kv_(
+                    kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
+                    self.interleaved
+                )
+            else:
+                kv = apply_rotary_emb_kv_(
+                    kv, self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
+                    self.interleaved
+                )
+            return q, kv
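The forward method now supports two call patterns: the existing packed-qkv path for standard MHA, and a separate q / kv path for MQA/GQA, where the two tensors carry different numbers of heads. A hypothetical usage sketch (the shapes follow the docstring above; the default construction RotaryEmbedding(headdim) and the CUDA half-precision tensors are assumptions, since the fused rotary kernel runs on the GPU):

import torch
from flash_attn.layers.rotary import RotaryEmbedding

batch, seqlen, nheads, nheads_kv, headdim = 2, 128, 16, 2, 64
rotary = RotaryEmbedding(headdim).to('cuda')  # assumed: default constructor arguments

# MHA path: one packed qkv tensor, rotated and returned as before
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device='cuda', dtype=torch.float16)
qkv = rotary(qkv)

# MQA/GQA path: q and kv have different numbers of heads, so they are passed separately
q = torch.randn(batch, seqlen, nheads, headdim, device='cuda', dtype=torch.float16)
kv = torch.randn(batch, seqlen, 2, nheads_kv, headdim, device='cuda', dtype=torch.float16)
q, kv = rotary(q, kv=kv)  # rotary applied to q and to the k half of kv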
flash_attn/models/gpt.py
@@ -88,7 +88,9 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
     parallel_kwargs = ({'process_group': process_group,
                         'sequence_parallel': getattr(config, 'sequence_parallel', True)}
                        if process_group is not None else {})
+    num_heads_kv = getattr(config, "n_head_kv", None)
     mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads,
+                        num_heads_kv=num_heads_kv,
                         qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias,
                         dropout=config.attn_pdrop,
                         softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
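Because num_heads_kv is read with getattr(config, "n_head_kv", None), MQA/GQA is opted into simply by setting that attribute on the config. A hypothetical sketch (the model class and the other config values are illustrative assumptions, not taken from this diff):

from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config(n_embd=1024, n_head=16, n_layer=24)  # assumed sizes for illustration
config.n_head_kv = 4   # 16 query heads share 4 k/v heads (GQA); n_head_kv=1 would be MQA
model = GPTLMHeadModel(config)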
@@ -503,20 +505,37 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
     assert inner_dim % world_size == 0

     def shard_first_dim(state_dict, key):
-        x = state_dict[key]
-        dim = x.shape[0] // world_size
-        state_dict[key] = x[rank * dim:(rank + 1) * dim]
+        if key in state_dict:
+            x = state_dict[key]
+            dim = x.shape[0] // world_size
+            state_dict[key] = x[rank * dim:(rank + 1) * dim]

     def shard_last_dim(state_dict, key):
-        x = state_dict[key]
-        dim = x.shape[-1] // world_size
-        state_dict[key] = x[..., rank * dim:(rank + 1) * dim]
+        if key in state_dict:
+            x = state_dict[key]
+            dim = x.shape[-1] // world_size
+            state_dict[key] = x[..., rank * dim:(rank + 1) * dim]

     def shard_qkv_headdim(state_dict, key):
-        x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
-        dim = x.shape[1] // world_size
-        state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
-                                    'three d ... -> (three d) ...')
+        if key in state_dict:
+            n_head = config.n_head
+            n_head_kv = getattr(config, 'n_head_kv', n_head)
+            assert n_head % world_size == 0 and n_head_kv % world_size == 0
+            if n_head_kv == n_head:
+                x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
+                dim = x.shape[1] // world_size
+                state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
+                                            'three d ... -> (three d) ...')
+            else:
+                n_head_per_rank = n_head // world_size
+                n_head_kv_per_rank = n_head_kv // world_size
+                x = rearrange(state_dict[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...',
+                              nheadqkv=n_head + 2 * n_head_kv)
+                state_dict[key] = rearrange(
+                    torch.cat([x[rank * n_head_per_rank:(rank + 1) * n_head_per_rank],
+                               x[n_head + rank * n_head_kv_per_rank
+                                 :n_head + (rank + 1) * n_head_kv_per_rank],
+                               x[n_head + n_head_kv + rank * n_head_kv_per_rank
+                                 :n_head + n_head_kv + (rank + 1) * n_head_kv_per_rank],
+                               ], dim=0),
+                    "nheadqkv headdim ... -> (nheadqkv headdim) ..."
+                )

     shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
     if 'lm_head.weight' in state_dict:
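A toy walkthrough of the GQA branch above, with made-up sizes: 8 query heads, 2 k/v heads, head dimension 4 and two tensor-parallel ranks, so each rank keeps 4 query heads plus 1 key head and 1 value head out of the packed Wqkv rows.

import torch
from einops import rearrange

n_head, n_head_kv, headdim, world_size, rank = 8, 2, 4, 2, 0
n_head_per_rank = n_head // world_size        # 4 query heads per rank
n_head_kv_per_rank = n_head_kv // world_size  # 1 key head and 1 value head per rank

# Packed Wqkv weight: rows are laid out as [q heads..., k heads..., v heads...] x headdim
Wqkv = torch.randn((n_head + 2 * n_head_kv) * headdim, 32)
x = rearrange(Wqkv, '(nheadqkv headdim) ... -> nheadqkv headdim ...',
              nheadqkv=n_head + 2 * n_head_kv)
shard = torch.cat([
    x[rank * n_head_per_rank:(rank + 1) * n_head_per_rank],        # rank 0's q heads: 0-3
    x[n_head + rank * n_head_kv_per_rank
      :n_head + (rank + 1) * n_head_kv_per_rank],                  # its k head: head 8
    x[n_head + n_head_kv + rank * n_head_kv_per_rank
      :n_head + n_head_kv + (rank + 1) * n_head_kv_per_rank],      # its v head: head 10
], dim=0)
shard = rearrange(shard, 'nheadqkv headdim ... -> (nheadqkv headdim) ...')
assert shard.shape[0] == (n_head + 2 * n_head_kv) // world_size * headdim  # 24 of 48 rows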
@@ -528,12 +547,12 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
         shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
         shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
         if rank != 0:
-            state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias')
+            state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias', None)
         shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
         shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
         shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
         if rank != 0:
-            state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias')
+            state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias', None)
     return state_dict
@@ -561,9 +580,23 @@ def combine_state_dicts_tp(state_dicts, config):
         state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

     def combine_qkv_headdim(state_dicts, state_dict, key):
+        n_head = config.n_head
+        n_head_kv = getattr(config, 'n_head_kv', n_head)
+        assert n_head % world_size == 0 and n_head_kv % world_size == 0
+        n_head_per_rank = n_head // world_size
+        n_head_kv_per_rank = n_head_kv // world_size
         if key in state_dict:
-            xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3) for s in state_dicts]
-            state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...')
+            if n_head_kv == n_head:
+                xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3)
+                      for s in state_dicts]
+                state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...')
+            else:
+                xs = [rearrange(s[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...',
+                                nheadqkv=n_head + 2 * n_head_kv) for s in state_dicts]
+                state_dict[key] = rearrange(
+                    torch.cat([torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
+                               torch.cat([x[n_head_per_rank:n_head_per_rank + n_head_kv_per_rank]
+                                          for x in xs], dim=0),
+                               torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0),
+                               ], dim=0),
+                    "nheadqkv headdim ... -> (nheadqkv headdim) ..."
+                )

     def combine_gated_mlp(state_dicts, state_dict, key):
         if key in state_dict:
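The combine path is the exact inverse of the shard path: within each rank's shard the first n_head_per_rank heads are query heads, the next n_head_kv_per_rank are key heads, and the last n_head_kv_per_rank are value heads, which is what the three slices in combine_qkv_headdim pick out. A toy round-trip check on head indices (sizes made up for illustration):

import torch

n_head, n_head_kv, world_size = 8, 2, 2
n_head_per_rank, n_head_kv_per_rank = n_head // world_size, n_head_kv // world_size

heads = torch.arange(n_head + 2 * n_head_kv)  # packed head indices: 8 q, then 2 k, then 2 v
shards = [torch.cat([heads[r * n_head_per_rank:(r + 1) * n_head_per_rank],
                     heads[n_head + r * n_head_kv_per_rank:n_head + (r + 1) * n_head_kv_per_rank],
                     heads[n_head + n_head_kv + r * n_head_kv_per_rank
                           :n_head + n_head_kv + (r + 1) * n_head_kv_per_rank]])
          for r in range(world_size)]          # what shard_qkv_headdim keeps on each rank

combined = torch.cat([torch.cat([s[:n_head_per_rank] for s in shards]),
                      torch.cat([s[n_head_per_rank:n_head_per_rank + n_head_kv_per_rank]
                                 for s in shards]),
                      torch.cat([s[-n_head_kv_per_rank:] for s in shards])])
assert torch.equal(combined, heads)            # combining the shards recovers the packed layout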
flash_attn/modules/mha.py

(This diff is collapsed in the original view: +209 additions, -121 deletions, not shown here.)
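Most of the MQA/GQA support lives in this collapsed mha.py diff. As a rough sketch of the idea it implements (illustrative only, not the module's actual code): each group of query heads attends against one shared key/value head, which can be expressed by repeating the k/v heads up to the number of query heads before ordinary attention.

import torch

def gqa_attention_reference(q, kv, causal=True):
    # q:  (batch, seqlen, nheads, headdim)
    # kv: (batch, seqlen, 2, nheads_kv, headdim) with nheads % nheads_kv == 0
    batch, seqlen, nheads, headdim = q.shape
    nheads_kv = kv.shape[3]
    k, v = kv.unbind(dim=2)
    # Repeat each k/v head so every group of query heads sees its shared k/v head
    k = k.repeat_interleave(nheads // nheads_kv, dim=2)
    v = v.repeat_interleave(nheads // nheads_kv, dim=2)
    scores = torch.einsum('bthd,bshd->bhts', q, k) / headdim ** 0.5
    if causal:
        mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=q.device), 1)
        scores = scores.masked_fill(mask, float('-inf'))
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum('bhts,bshd->bthd', attn, v)  # (batch, seqlen, nheads, headdim)

MQA is the nheads_kv = 1 special case. The fused kernels avoid materializing the repeated k/v; this reference only shows the math.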
tests/models/test_gpt_generation.py
@@ -60,9 +60,9 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
     torch.manual_seed(0)
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-    input_ids = tokenizer("Hello, my dog is cute and",
+    input_ids = tokenizer("Hello, my dog is cute and he",
                           return_tensors="pt").input_ids.to(device=device)
-    max_length = 30
+    max_length = 25
     # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
     # max_length = input_ids.shape[1] + 40
@@ -143,9 +143,9 @@ def test_greedy_decode_opt(model_name):
     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
     eos_token_id = tokenizer.eos_token_id
-    input_ids = tokenizer("Hello, my dog is cute and",
+    input_ids = tokenizer("Hello, my dog is cute and he",
                           return_tensors="pt").input_ids.to(device=device)
-    max_length = 60
+    max_length = 25
     # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
     # max_length = input_ids.shape[1] + 40
tests/models/test_gptj.py
@@ -48,7 +48,6 @@ def test_gptj_optimized(model_name):
     torch.manual_seed(0)
     batch_size = 2
     max_seqlen = 256
-    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
     input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                               device=device)
     with torch.no_grad():