gaoqiong / flash-attention · Commits

Commit 25d6b1db (unverified)
Authored Aug 20, 2023 by Xuechen Li; committed by GitHub on Aug 20, 2023
Parent: d431f167

handle uneven heads across ranks when combining state_dicts; resolves #467 (#468)

* q
* add comment.
Showing 1 changed file with 38 additions and 20 deletions

flash_attn/models/gpt.py (+38, -20), view file @ 25d6b1db
@@ -20,16 +20,12 @@ from flash_attn.models.opt import remap_state_dict_hf_opt
 from flash_attn.modules.block import Block, ParallelBlock
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import (
-    FusedMLP,
-    GatedMlp,
-    Mlp,
-    ParallelFusedMLP,
-    ParallelGatedMlp,
-    ParallelMLP,
-)
+from flash_attn.modules.mlp import (FusedMLP, GatedMlp, Mlp, ParallelFusedMLP,
+                                    ParallelGatedMlp, ParallelMLP)
 from flash_attn.ops.activations import sqrelu_fwd
-from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
+from flash_attn.utils.distributed import (all_gather_raw,
+                                          get_dim_for_local_rank,
+                                          sync_shared_params)
 from flash_attn.utils.generation import GenerationMixin
 from flash_attn.utils.pretrained import state_dict_from_pretrained
@@ -44,7 +40,8 @@ except ImportError:
     dropout_add_layer_norm = None

 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
+    from flash_attn.ops.layer_norm import \
+        dropout_add_layer_norm_parallel_residual
 except ImportError:
     dropout_add_layer_norm_parallel_residual = None
@@ -673,6 +670,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
 def shard_state_dict_tp(state_dict, config, world_size, rank):
     """Convert the state_dict of a standard GPT model to the state_dict of a GPT model
     with tensor parallel.
+
+    This function modifies state_dict in place.
     """
     pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
     vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
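A note on the line added in this hunk: shard_state_dict_tp returns the dict it was given, but it also modifies that dict in place. A minimal usage sketch under that reading; full_state_dict and config are placeholders for an already-loaded full GPT state dict and its GPT2Config, not values defined in this diff.

    import copy

    from flash_attn.models.gpt import shard_state_dict_tp

    # Placeholders (assumptions): full_state_dict is a complete, unsharded GPT
    # state dict and config is its GPT2Config; neither is defined in this diff.
    world_size = 2
    shards = [
        # Deep-copy first: shard_state_dict_tp modifies the dict in place.
        shard_state_dict_tp(copy.deepcopy(full_state_dict), config, world_size, rank)
        for rank in range(world_size)
    ]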
@@ -784,11 +783,14 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
     return state_dict


-def combine_state_dicts_tp(state_dicts, config):
-    """Convert the state_dict of a GPT model with tensor parallel to the state_dict of a
-    standard GPT model.
+def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: GPT2Config):
+    """Convert the list of sharded state_dict of a GPT model with tensor parallel to
+    the state_dict of a standard GPT model.
+    This function is meant to be the "reverse" of shard_state_dict_tp.
+    Precondition:
+        - state_dicts should be ordered in the same way as the shards were created.
     """
     world_size = len(state_dicts)
     keys = state_dicts[0].keys()
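The new docstring spells out the calling contract: pass the shards in rank order, the same order in which they were produced. A usage sketch under that assumption; the checkpoint file naming and the already-loaded config are illustrative placeholders, not conventions taken from this repository.

    import torch

    from flash_attn.models.gpt import combine_state_dicts_tp

    world_size = 4  # assumed number of tensor-parallel ranks
    # Load the per-rank shards indexed by rank, so the list order matches the
    # order in which the shards were created (the docstring's precondition).
    state_dicts = [
        torch.load(f"checkpoint_rank_{rank}.pt", map_location="cpu")  # hypothetical file names
        for rank in range(world_size)
    ]
    # config is a placeholder for the matching GPT2Config (not defined here).
    full_state_dict = combine_state_dicts_tp(state_dicts, config)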
@@ -812,9 +814,6 @@ def combine_state_dicts_tp(state_dicts, config):
     def combine_qkv_headdim(state_dicts, state_dict, key):
         n_head = config.n_head
         n_head_kv = getattr(config, "n_head_kv", n_head)
-        assert n_head % world_size == 0 and n_head_kv % world_size == 0
-        n_head_per_rank = n_head // world_size
-        n_head_kv_per_rank = n_head_kv // world_size
         if key in state_dict:
             if n_head_kv == n_head:
                 xs = [
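The three deleted lines are the even-split assumption the commit message is about: a single n_head_per_rank only makes sense when world_size divides both n_head and n_head_kv. The next hunk replaces it with one count per rank via get_dim_for_local_rank. A small illustration of the difference, using a hypothetical stand-in for that helper; the assumption that earlier ranks take the extra heads is mine, not stated in this diff.

    # Hypothetical stand-in for flash_attn.utils.distributed.get_dim_for_local_rank;
    # the tie-breaking rule (earlier ranks take the extra heads) is an assumption.
    def dim_for_local_rank(dim: int, world_size: int, local_rank: int) -> int:
        base, extra = divmod(dim, world_size)
        return base + (1 if local_rank < extra else 0)

    n_head, n_head_kv, world_size = 14, 6, 4

    # Old code: one shared count, valid only for even splits.
    # assert n_head % world_size == 0   # fails here: 14 % 4 == 2
    n_head_per_rank = n_head // world_size  # 3, which undercounts ranks 0 and 1

    # New code: one count per rank, so uneven splits are representable.
    n_head_each_rank = [dim_for_local_rank(n_head, world_size, r) for r in range(world_size)]
    n_head_kv_each_rank = [dim_for_local_rank(n_head_kv, world_size, r) for r in range(world_size)]
    print(n_head_per_rank, n_head_each_rank, n_head_kv_each_rank)  # 3 [4, 4, 3, 3] [2, 2, 1, 1]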
@@ -830,18 +829,37 @@ def combine_state_dicts_tp(state_dicts, config):
                     )
                     for s in state_dicts
                 ]
+                n_head_each_rank = [
+                    get_dim_for_local_rank(n_head, world_size, local_rank)
+                    for local_rank in range(world_size)
+                ]
+                n_head_kv_each_rank = [
+                    get_dim_for_local_rank(n_head_kv, world_size, local_rank)
+                    for local_rank in range(world_size)
+                ]
                 state_dict[key] = rearrange(
                     torch.cat(
                         [
-                            torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
+                            torch.cat(
+                                [x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0
+                            ),
+                            torch.cat(
+                                [
+                                    x[
+                                        n_head_each_rank[rank] : n_head_each_rank[rank]
+                                        + n_head_kv_each_rank[rank]
+                                    ]
+                                    for rank, x in enumerate(xs)
+                                ],
+                                dim=0,
+                            ),
                             torch.cat(
                                 [
-                                    x[n_head_per_rank : n_head_per_rank + n_head_kv_per_rank]
-                                    for x in xs
+                                    x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
+                                    for rank, x in enumerate(xs)
                                 ],
                                 dim=0,
                             ),
-                            torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0),
                         ],
                         dim=0,
                     ),
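The rewritten nested torch.cat treats each rank's packed projection as [query heads | key heads | value heads] along the head dimension, with per-rank counts taken from the lists added above, and stitches the full tensor back together as [all Q | all K | all V]. A self-contained sketch of that slicing pattern with made-up shapes, using plain reshape/view instead of einops.rearrange; it illustrates the logic and is not the library's code.

    import torch

    torch.manual_seed(0)
    head_dim, hidden = 2, 4
    n_head_each_rank = [3, 3, 2]     # 8 query heads, split unevenly over 3 ranks
    n_head_kv_each_rank = [2, 1, 1]  # 4 key/value heads, split unevenly over 3 ranks

    # Fake per-rank shards shaped (q_heads + k_heads + v_heads, head_dim, hidden),
    # i.e. each shard is packed as [q heads | k heads | v heads].
    xs = [
        torch.randn((nh + 2 * nhk) * head_dim, hidden).view(nh + 2 * nhk, head_dim, hidden)
        for nh, nhk in zip(n_head_each_rank, n_head_kv_each_rank)
    ]

    # Mirror the nested torch.cat calls: gather all Q slices, then all K, then all V.
    q = torch.cat([x[: n_head_each_rank[r]] for r, x in enumerate(xs)], dim=0)
    k = torch.cat(
        [
            x[n_head_each_rank[r] : n_head_each_rank[r] + n_head_kv_each_rank[r]]
            for r, x in enumerate(xs)
        ],
        dim=0,
    )
    v = torch.cat([x[n_head_each_rank[r] + n_head_kv_each_rank[r] :] for r, x in enumerate(xs)], dim=0)

    # Flatten heads back into a single (total_heads * head_dim, hidden) weight.
    combined = torch.cat([q, k, v], dim=0).reshape(-1, hidden)
    print(combined.shape)  # torch.Size([32, 4]): (8 + 4 + 4) heads * head_dim 2 rows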