gaoqiong / flash-attention, commit ef1ba918
Authored Jan 01, 2023 by Tri Dao
[GPT] Refactor function to shard state_dict for TensorParallel
Parent: 65b4064b
Showing 2 changed files with 87 additions and 63 deletions:

  flash_attn/models/gpt.py            +50   -0
  tests/models/test_gpt_parallel.py   +37  -63
flash_attn/models/gpt.py (view file @ ef1ba918)
@@ -14,6 +14,8 @@ import torch.nn.functional as F

from transformers import GPT2Config
from einops import rearrange

from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense
from flash_attn.modules.block import Block
@@ -338,3 +340,51 @@ def remap_state_dict_gpt2(state_dict, config):
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict


def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
    with tensor parallel.
    """
    vocab_size = config.vocab_size
    if config.vocab_size % config.pad_vocab_size_multiple != 0:
        vocab_size += (config.pad_vocab_size_multiple
                       - (config.vocab_size % config.pad_vocab_size_multiple))
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    def shard_first_dim(state_dict, key):
        x = state_dict[key]
        dim = x.shape[0] // world_size
        state_dict[key] = x[rank * dim:(rank + 1) * dim]

    def shard_last_dim(state_dict, key):
        x = state_dict[key]
        dim = x.shape[-1] // world_size
        state_dict[key] = x[..., rank * dim:(rank + 1) * dim]

    def shard_qkv_headdim(state_dict, key):
        x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
        dim = x.shape[1] // world_size
        state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
                                    'three d ... -> (three d) ...')

    shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
    if 'lm_head.weight' in state_dict:
        shard_first_dim(state_dict, 'lm_head.weight')
    if 'transformer.embeddings.position_embeddings.weight' in state_dict:
        shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight')
    for i in range(config.num_hidden_layers):
        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
        shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
        shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
        if rank != 0:
            state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias')
        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
        shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
        if rank != 0:
            state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias')
    return state_dict
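For orientation, a minimal usage sketch of the new helper (not part of the commit): it shards a full, unsharded GPT state_dict for one tensor-parallel rank and loads it into a parallel model. It assumes torch.distributed has already been initialized (e.g. via a torchrun-style launch), that the padded vocab size, hidden size, and MLP inner dimension divide evenly by the world size, and that GPTLMHeadModel accepts a process_group argument as in the test below; the checkpoint path and config values are hypothetical.

# Hypothetical usage sketch: shard a full GPT state_dict for one tensor-parallel rank.
import torch
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp

torch.distributed.init_process_group(backend='nccl')          # assumes torchrun-style env vars
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
process_group = torch.distributed.new_group(ranks=list(range(world_size)))
torch.cuda.set_device(rank)

config = GPT2Config(n_embd=1024, n_head=16, n_layer=4, vocab_size=50257)
config.pad_vocab_size_multiple = 8 * world_size               # attribute read by shard_state_dict_tp

full_state_dict = torch.load('gpt2_full.pt')                  # hypothetical unsharded checkpoint
sharded = shard_state_dict_tp(full_state_dict, config, world_size, rank)

# Assumed constructor signature, matching how the test below builds the parallel model.
model = GPTLMHeadModel(config, process_group=process_group, device='cuda')
model.load_state_dict(sharded)
model.tie_weights()   # re-tie lm_head to the (sharded) word embedding, as done in the test

The test below exercises exactly this path: it shards model_pt.state_dict() with the same function and loads it into the tensor-parallel model.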
tests/models/test_gpt_parallel.py (view file @ ef1ba918)
@@ -12,7 +12,7 @@ from transformers import GPT2Config
 from apex.transformer import parallel_state
-from flash_attn.models.gpt import GPTLMHeadModel
+from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp
 from flash_attn.losses.cross_entropy import CrossEntropyLoss
 from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
@@ -22,11 +22,11 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
 @pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
-# @pytest.mark.parametrize('world_size', [1])
+# @pytest.mark.parametrize('world_size', [2])
 @pytest.mark.parametrize('has_pos_emb', [True, False])
 # @pytest.mark.parametrize('has_pos_emb', [True])
 @pytest.mark.parametrize('dim', [1024])
-def test_block_parallel(dim, has_pos_emb, world_size, dtype):
+def test_gpt_parallel(dim, has_pos_emb, world_size, dtype):
     head_dim = 64
     assert dim % head_dim == 0
     num_heads = dim // head_dim
@@ -91,45 +91,8 @@ def test_block_parallel(dim, has_pos_emb, world_size, dtype):
     partition_dim = dim // world_size
     partition_hidden_dim = 4 * dim // world_size
-    with torch.no_grad():
-        model.transformer.embeddings.word_embeddings.weight.copy_(
-            model_pt.transformer.embeddings.word_embeddings.weight[rank * partition_vocab_size:(rank + 1) * partition_vocab_size]
-        )
-        if has_pos_emb:
-            model.transformer.embeddings.position_embeddings.weight.copy_(
-                model_pt.transformer.embeddings.position_embeddings.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
-            )
-        model.transformer.ln_0.weight.copy_(model_pt.transformer.ln_0.weight)
-        model.transformer.ln_0.bias.copy_(model_pt.transformer.ln_0.bias)
-        for i in range(num_layers):
-            model.transformer.layers[i].mixer.Wqkv.weight.copy_(
-                rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.weight,
-                                    '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
-                          'three o i -> (three o) i')
-            )
-            model.transformer.layers[i].mixer.Wqkv.bias.copy_(
-                rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.bias,
-                                    '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
-                          'three o -> (three o)')
-            )
-            model.transformer.layers[i].mixer.out_proj.weight.copy_(
-                model_pt.transformer.layers[i].mixer.out_proj.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
-            )
-            if rank == 0:
-                model.transformer.layers[i].mixer.out_proj.bias.copy_(model_pt.transformer.layers[i].mixer.out_proj.bias)
-            model.transformer.layers[i].mlp.fc1.weight.copy_(
-                model_pt.transformer.layers[i].mlp.fc1.weight[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
-            )
-            model.transformer.layers[i].mlp.fc1.bias.copy_(
-                model_pt.transformer.layers[i].mlp.fc1.bias[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
-            )
-            model.transformer.layers[i].mlp.fc2.weight.copy_(
-                model_pt.transformer.layers[i].mlp.fc2.weight[:, rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim]
-            )
-            if rank == 0:
-                model.transformer.layers[i].mlp.fc2.bias.copy_(model_pt.transformer.layers[i].mlp.fc2.bias)
-            model.transformer.layers[i].norm1.weight.copy_(model_pt.transformer.layers[i].norm1.weight)
-            model.transformer.layers[i].norm1.bias.copy_(model_pt.transformer.layers[i].norm1.bias)
-            model.transformer.layers[i].norm2.weight.copy_(model_pt.transformer.layers[i].norm2.weight)
-            model.transformer.layers[i].norm2.bias.copy_(model_pt.transformer.layers[i].norm2.bias)
-        # Don't need to copy the lm_head weight since it's tied to the word embedding weight
+    model.load_state_dict(shard_state_dict_tp(model_pt.state_dict(), config, world_size, rank))
+    model.tie_weights()
     with torch.autocast(device_type='cuda', dtype=dtype):
         out = model(input_ids[:, :-1]).logits
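As a side note (illustrative, not part of the commit): the manual Wqkv copies removed above and the new shard_qkv_headdim helper use the same interleaved '(three d)' packing of the fused QKV projection, so the per-rank shards can be recombined into the full weight. A small self-contained sketch of that round trip, with toy dimensions:

# Round-trip sketch for the QKV sharding pattern used by shard_state_dict_tp.
import torch
from einops import rearrange

world_size, hidden = 4, 16
wqkv = torch.randn(3 * hidden, hidden)   # fused QKV weight, rows packed as (q, k, v)

def shard_qkv(w, world_size, rank):
    # Same pattern as shard_qkv_headdim: split each of q/k/v along its output dim.
    x = rearrange(w, '(three d) ... -> three d ...', three=3)
    d = x.shape[1] // world_size
    return rearrange(x[:, rank * d:(rank + 1) * d], 'three d ... -> (three d) ...')

shards = [shard_qkv(wqkv, world_size, r) for r in range(world_size)]
# Unpack each shard, concatenate along the sharded dim in rank order, and re-pack.
recovered = rearrange(
    torch.cat([rearrange(s, '(three d) ... -> three d ...', three=3) for s in shards], dim=1),
    'three d ... -> (three d) ...')
assert torch.equal(recovered, wqkv)

This is also why, in the gradient checks below, the per-rank Wqkv gradients can be compared directly against the grad_dict entries produced by the same function.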
@@ -150,62 +113,73 @@ def test_block_parallel(dim, has_pos_emb, world_size, dtype):
     allreduce_sequence_parallel_grad(model, process_group)
     parallel_state.destroy_model_parallel()

+    grad_dict = shard_state_dict_tp({k: v.grad for k, v in model_pt.named_parameters()},
+                                    config, world_size, rank)
+
-    assert torch.allclose(
-        model.transformer.embeddings.word_embeddings.weight.grad,
-        model_pt.transformer.embeddings.word_embeddings.weight.grad[rank * partition_vocab_size:(rank + 1) * partition_vocab_size],
-        rtol=rtol, atol=atol * 5
-    )
+    assert torch.allclose(
+        model.transformer.embeddings.word_embeddings.weight.grad,
+        grad_dict['transformer.embeddings.word_embeddings.weight'],
+        rtol=rtol, atol=atol * 5
+    )
     if has_pos_emb:
-        assert torch.allclose(
-            model.transformer.embeddings.position_embeddings.weight.grad,
-            model_pt.transformer.embeddings.position_embeddings.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
-            rtol=rtol, atol=atol
-        )
+        assert torch.allclose(
+            model.transformer.embeddings.position_embeddings.weight.grad,
+            grad_dict['transformer.embeddings.position_embeddings.weight'],
+            rtol=rtol, atol=atol
+        )
-    assert torch.allclose(model.transformer.ln_0.weight.grad, model_pt.transformer.ln_0.weight.grad,
-                          rtol=rtol, atol=atol)
+    assert torch.allclose(model.transformer.ln_0.weight.grad, grad_dict['transformer.ln_0.weight'],
+                          rtol=rtol, atol=atol)
-    assert torch.allclose(model.transformer.ln_0.bias.grad, model_pt.transformer.ln_0.bias.grad,
-                          rtol=rtol, atol=atol)
+    assert torch.allclose(model.transformer.ln_0.bias.grad, grad_dict['transformer.ln_0.bias'],
+                          rtol=rtol, atol=atol)
     for i in range(num_layers):
         # if rank == 0: breakpoint()
         # torch.distributed.barrier()
-        assert torch.allclose(
-            model.transformer.layers[i].mixer.Wqkv.weight.grad,
-            rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.weight.grad,
-                                '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
-                      'three o i -> (three o) i'),
-            rtol=rtol, atol=atol * 10
-        )
+        assert torch.allclose(
+            model.transformer.layers[i].mixer.Wqkv.weight.grad,
+            grad_dict[f'transformer.layers.{i}.mixer.Wqkv.weight'],
+            rtol=rtol, atol=atol * 10
+        )
-        assert torch.allclose(
-            model.transformer.layers[i].mixer.Wqkv.bias.grad,
-            rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.bias.grad,
-                                '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
-                      'three o -> (three o)'),
-            rtol=rtol, atol=atol * 10
-        )
+        assert torch.allclose(
+            model.transformer.layers[i].mixer.Wqkv.bias.grad,
+            grad_dict[f'transformer.layers.{i}.mixer.Wqkv.bias'],
+            rtol=rtol, atol=atol * 10
+        )
-        assert torch.allclose(
-            model.transformer.layers[i].mixer.out_proj.weight.grad,
-            model_pt.transformer.layers[i].mixer.out_proj.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
-            rtol=rtol, atol=atol * 10
-        )
+        assert torch.allclose(
+            model.transformer.layers[i].mixer.out_proj.weight.grad,
+            grad_dict[f'transformer.layers.{i}.mixer.out_proj.weight'],
+            rtol=rtol, atol=atol * 10
+        )
         if rank == 0:
-            assert torch.allclose(model.transformer.layers[i].mixer.out_proj.bias.grad,
-                                  model_pt.transformer.layers[i].mixer.out_proj.bias.grad,
-                                  rtol=rtol, atol=atol * 5)
+            assert torch.allclose(model.transformer.layers[i].mixer.out_proj.bias.grad,
+                                  grad_dict[f'transformer.layers.{i}.mixer.out_proj.bias'],
+                                  rtol=rtol, atol=atol * 5)
-        assert torch.allclose(
-            model.transformer.layers[i].mlp.fc1.weight.grad,
-            model_pt.transformer.layers[i].mlp.fc1.weight.grad[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
-            rtol=rtol, atol=atol * 10
-        )
+        assert torch.allclose(
+            model.transformer.layers[i].mlp.fc1.weight.grad,
+            grad_dict[f'transformer.layers.{i}.mlp.fc1.weight'],
+            rtol=rtol, atol=atol * 10
+        )
-        assert torch.allclose(
-            model.transformer.layers[i].mlp.fc1.bias.grad,
-            model_pt.transformer.layers[i].mlp.fc1.bias.grad[rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
-            rtol=rtol, atol=atol * 10
-        )
+        assert torch.allclose(
+            model.transformer.layers[i].mlp.fc1.bias.grad,
+            grad_dict[f'transformer.layers.{i}.mlp.fc1.bias'],
+            rtol=rtol, atol=atol * 10
+        )
-        assert torch.allclose(
-            model.transformer.layers[i].mlp.fc2.weight.grad,
-            model_pt.transformer.layers[i].mlp.fc2.weight.grad[:, rank * partition_hidden_dim:(rank + 1) * partition_hidden_dim],
-            rtol=rtol, atol=atol * 10
-        )
+        assert torch.allclose(
+            model.transformer.layers[i].mlp.fc2.weight.grad,
+            grad_dict[f'transformer.layers.{i}.mlp.fc2.weight'],
+            rtol=rtol, atol=atol * 10
+        )
         if rank == 0:
-            assert torch.allclose(model.transformer.layers[i].mlp.fc2.bias.grad,
-                                  model_pt.transformer.layers[i].mlp.fc2.bias.grad,
-                                  rtol=rtol, atol=atol * 5)
-        assert torch.allclose(model.transformer.layers[i].norm1.weight.grad,
-                              model_pt.transformer.layers[i].norm1.weight.grad, rtol=rtol, atol=atol)
-        assert torch.allclose(model.transformer.layers[i].norm1.bias.grad,
-                              model_pt.transformer.layers[i].norm1.bias.grad, rtol=rtol, atol=atol)
-        assert torch.allclose(model.transformer.layers[i].norm2.weight.grad,
-                              model_pt.transformer.layers[i].norm2.weight.grad, rtol=rtol, atol=atol)
-        assert torch.allclose(model.transformer.layers[i].norm2.bias.grad,
-                              model_pt.transformer.layers[i].norm2.bias.grad, rtol=rtol, atol=atol)
+            assert torch.allclose(model.transformer.layers[i].mlp.fc2.bias.grad,
+                                  grad_dict[f'transformer.layers.{i}.mlp.fc2.bias'],
+                                  rtol=rtol, atol=atol * 5)
+        assert torch.allclose(model.transformer.layers[i].norm1.weight.grad,
+                              grad_dict[f'transformer.layers.{i}.norm1.weight'], rtol=rtol, atol=atol)
+        assert torch.allclose(model.transformer.layers[i].norm1.bias.grad,
+                              grad_dict[f'transformer.layers.{i}.norm1.bias'], rtol=rtol, atol=atol)
+        assert torch.allclose(model.transformer.layers[i].norm2.weight.grad,
+                              grad_dict[f'transformer.layers.{i}.norm2.weight'], rtol=rtol, atol=atol)
+        assert torch.allclose(model.transformer.layers[i].norm2.bias.grad,
+                              grad_dict[f'transformer.layers.{i}.norm2.bias'], rtol=rtol, atol=atol)