gaoqiong / flash-attention · Commits

Commit 78225c53, authored Dec 25, 2022 by Tri Dao
Parent: a8cfe515

Implement Tensor Parallel for GPT2Embeddings

Showing 2 changed files with 166 additions and 7 deletions:
  flash_attn/modules/embedding.py            +82  -7
  tests/modules/test_embedding_parallel.py   +84  -0
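
The commit shards the word-embedding table along the vocabulary dimension and the position-embedding table along the embedding dimension, then combines the per-rank results with a reduce-scatter, so both sizes must divide evenly by the tensor-parallel world size. A minimal sketch (not part of the commit) of the resulting shard sizes, using the vocab_size and embed_dim that appear in the test below:

    vocab_size, embed_dim = 50264, 1024   # sizes used in the test below
    for world_size in [1, 2, 4, 8]:
        assert vocab_size % world_size == 0 and embed_dim % world_size == 0
        print(world_size,
              vocab_size // world_size,   # rows of each rank's word-embedding shard
              embed_dim // world_size)    # columns of each rank's position-embedding shard

For world_size=8 this gives a 6283-row word-embedding shard and a 128-column position-embedding shard per rank.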
flash_attn/modules/embedding.py  (view file @ 78225c53)

@@ -3,18 +3,26 @@
 import torch
 import torch.nn as nn
 from einops import rearrange
 
+from flash_attn.utils.distributed import reduce_scatter
+
 
 class GPT2Embeddings(nn.Module):
 
-    def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None):
+    def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
+                 device=None, dtype=None):
         """
             If max_position_embeddings <= 0, there's no position embeddings
         """
+        factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
+        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
+                                            **factory_kwargs)
         self.max_position_embeddings = max_position_embeddings
         if self.max_position_embeddings > 0:
-            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim)
+            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
+                                                    **factory_kwargs)
 
     def forward(self, input_ids, position_ids=None):
         """
@@ -34,19 +42,23 @@ class GPT2Embeddings(nn.Module):
 class BertEmbeddings(nn.Module):
 
     def __init__(self, embed_dim, vocab_size, max_position_embeddings, type_vocab_size,
-                 padding_idx=None):
+                 padding_idx=None, device=None, dtype=None):
         """
             If max_position_embeddings <= 0, there's no position embeddings
             If type_vocab_size <= 0, there's no token type embeddings
         """
+        factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
+        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
+                                            **factory_kwargs)
         self.max_position_embeddings = max_position_embeddings
         self.type_vocab_size = type_vocab_size
         if self.max_position_embeddings > 0:
-            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim)
+            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
+                                                    **factory_kwargs)
         if self.type_vocab_size > 0:
-            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim)
+            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim,
+                                                      **factory_kwargs)
 
     def forward(self, input_ids, position_ids=None, token_type_ids=None):
         """
@@ -67,3 +79,66 @@ class BertEmbeddings(nn.Module):
             token_type_embeddings = self.token_type_embeddings(token_type_ids)
             embeddings = embeddings + token_type_embeddings
         return embeddings
+
+
+class ParallelGPT2Embeddings(nn.Module):
+
+    def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
+                 padding_idx=None, device=None, dtype=None):
+        """
+            If max_position_embeddings <= 0, there's no position embeddings
+        """
+        world_size = torch.distributed.get_world_size(process_group)
+        if vocab_size % world_size != 0:
+            raise ValueError(f'vocab_size ({vocab_size}) must be divisible by '
+                             f'world_size ({world_size})')
+        if embed_dim % world_size != 0:
+            raise ValueError(f'embed_dim ({embed_dim}) must be divisible by '
+                             f'world_size ({world_size})')
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.process_group = process_group
+        self.word_embeddings = nn.Embedding(vocab_size // world_size, embed_dim,
+                                            padding_idx=padding_idx, **factory_kwargs)
+        self.max_position_embeddings = max_position_embeddings
+        if self.max_position_embeddings > 0:
+            self.position_embeddings = nn.Embedding(max_position_embeddings,
+                                                    embed_dim // world_size, **factory_kwargs)
+
+    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
+        """
+            input_ids: (batch, seqlen)
+            position_ids: (batch, seqlen)
+        """
+        batch_size, seqlen = input_ids.shape
+        world_size = torch.distributed.get_world_size(self.process_group)
+        if world_size <= 1:
+            embeddings = self.word_embeddings(input_ids)
+            if self.max_position_embeddings > 0:
+                if position_ids is None:
+                    position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
+                position_embeddings = self.position_embeddings(position_ids)
+                embeddings = embeddings + position_embeddings
+            if combine_batch_seqlen_dim:
+                embeddings = rearrange(embeddings, 'b s d -> (b s) d')
+            return embeddings
+        else:
+            rank = torch.distributed.get_rank(self.process_group)
+            vocab_size = self.word_embeddings.num_embeddings
+            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
+            # Create a mask of valid vocab ids (1 means it needs to be masked).
+            input_ids_mask = (input_ids < vocab_start_index) | (input_ids >= vocab_end_index)
+            input_ids = input_ids - vocab_start_index
+            input_ids[input_ids_mask] = 0
+            embeddings = self.word_embeddings(input_ids)
+            embeddings[input_ids_mask] = 0.0
+            if self.max_position_embeddings > 0:
+                if position_ids is None:
+                    position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
+                position_embeddings = self.position_embeddings(position_ids)
+                partition_dim = self.position_embeddings.embedding_dim
+                embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
+            if combine_batch_seqlen_dim:
+                embeddings = rearrange(embeddings, 'b s d -> (b s) d')
+            return reduce_scatter(embeddings, self.process_group)
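
The vocab-parallel forward path above relies on a masking trick: each rank shifts token ids into its local shard's index range, zeroes the embeddings of ids that fall outside its shard, and the reduce-scatter then sums these per-rank partial results (and splits the flattened batch across ranks). Below is a single-process sketch, not part of the commit, that simulates the ranks with a loop on CPU and toy sizes, and checks only the summation step against a full nn.Embedding lookup:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    torch.manual_seed(0)
    vocab_size, embed_dim, world_size = 16, 8, 4
    full = nn.Embedding(vocab_size, embed_dim)
    input_ids = torch.randint(0, vocab_size, (2, 5))

    shard = vocab_size // world_size
    partial_sum = torch.zeros(2, 5, embed_dim)
    with torch.no_grad():
        for rank in range(world_size):                       # stand-in for the per-rank computation
            start, end = rank * shard, (rank + 1) * shard
            mask = (input_ids < start) | (input_ids >= end)  # ids outside this rank's vocab shard
            local_ids = (input_ids - start).masked_fill(mask, 0)
            local_emb = F.embedding(local_ids, full.weight[start:end])
            local_emb = local_emb.masked_fill(mask.unsqueeze(-1), 0.0)
            partial_sum += local_emb                         # reduce(-scatter) sums these across ranks
        assert torch.allclose(partial_sum, full(input_ids))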
tests/modules/test_embedding_parallel.py  (new file, 0 → 100644, view file @ 78225c53)

# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_embedding_parallel.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytest

from einops import rearrange

from apex.transformer import parallel_state

from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings

is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8


@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('has_pos_emb', [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize('dim', [1024])
def test_embedding_parallel(dim, world_size, has_pos_emb, dtype):
    vocab_size = 50264
    seqlen = 2048
    assert vocab_size % world_size == 0
    assert dim % world_size == 0
    rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    input_ids_pt = torch.randint(0, vocab_size, (batch_size, seqlen), device=device)
    input_ids = input_ids_pt.detach().clone()
    model_pt = GPT2Embeddings(dim, vocab_size, seqlen if has_pos_emb else 0, device=device,
                              dtype=dtype)
    model = ParallelGPT2Embeddings(dim, vocab_size, seqlen if has_pos_emb else 0,
                                   parallel_state.get_tensor_model_parallel_group(),
                                   device=device, dtype=dtype)
    partition_vocab_size = vocab_size // world_size
    partition_dim = dim // world_size
    with torch.no_grad():
        model.word_embeddings.weight.copy_(
            model_pt.word_embeddings.weight[rank * partition_vocab_size:(rank + 1) * partition_vocab_size]
        )
        if has_pos_emb:
            model.position_embeddings.weight.copy_(
                model_pt.position_embeddings.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
            )
    out = model(input_ids, combine_batch_seqlen_dim=True)
    out_pt = rearrange(model_pt(input_ids), 'b s d -> (b s) d')
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
                          rtol=rtol, atol=atol)

    g = torch.randn_like(out_pt)
    out_pt.backward(g)
    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
    parallel_state.destroy_model_parallel()

    assert torch.allclose(
        model.word_embeddings.weight.grad,
        model_pt.word_embeddings.weight.grad[rank * partition_vocab_size:(rank + 1) * partition_vocab_size],
        rtol=rtol, atol=atol
    )
    if has_pos_emb:
        assert torch.allclose(
            model.position_embeddings.weight.grad,
            model_pt.position_embeddings.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
            rtol=rtol, atol=atol
        )
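
For completeness, a minimal usage sketch (not part of the commit) of ParallelGPT2Embeddings with a plain torch.distributed process group instead of apex's parallel_state; it assumes a torchrun launch with one GPU per rank and the NCCL backend, mirroring the setup in the test above:

    import torch
    from flash_attn.modules.embedding import ParallelGPT2Embeddings

    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    rank = torch.distributed.get_rank()
    torch.cuda.set_device(rank)
    process_group = torch.distributed.group.WORLD   # assumption: all ranks form one TP group

    embed_dim, vocab_size, max_seqlen = 1024, 50264, 1024
    emb = ParallelGPT2Embeddings(embed_dim, vocab_size, max_seqlen, process_group,
                                 device=f'cuda:{rank}', dtype=torch.float16)

    torch.manual_seed(0)   # every rank must see the same token ids
    input_ids = torch.randint(0, vocab_size, (8, 1024), device=f'cuda:{rank}')
    # The output is reduce-scattered across ranks: with combine_batch_seqlen_dim=True each
    # rank gets its (batch * seqlen // world_size, embed_dim) slice of the flattened batch.
    hidden = emb(input_ids, combine_batch_seqlen_dim=True)

In the commit's own test the process group instead comes from apex's parallel_state.initialize_model_parallel, which is the Megatron-style setup used elsewhere in this repository.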