gaoqiong / flash-attention

Commit 4cab4de5 authored Jan 02, 2023 by Tri Dao
Parent: 1ec09ebd

[TP] Put parallel embeddings in separate modules

Showing 1 changed file with 63 additions and 37 deletions.

flash_attn/modules/embedding.py (+63, -37)
@@ -2,6 +2,7 @@
 import torch
 import torch.nn as nn
+from torch import Tensor

 from einops import rearrange
@@ -81,6 +82,51 @@ class BertEmbeddings(nn.Module):
         return embeddings


+class VocabParallelEmbedding(nn.Embedding):
+
+    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
+        self.process_group = process_group
+        if process_group is not None:
+            world_size = torch.distributed.get_world_size(process_group)
+            if num_embeddings % world_size != 0:
+                raise ValueError(f'num_embeddings ({num_embeddings}) must be divisible by '
+                                 f'world_size ({world_size})')
+            if world_size > 1 and padding_idx is not None:
+                raise RuntimeError('ParallelEmbedding does not support padding_idx')
+        else:
+            world_size = 1
+        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
+
+    def forward(self, input: Tensor) -> Tensor:
+        if self.process_group is None:
+            return super().forward(input)
+        else:
+            rank = torch.distributed.get_rank(self.process_group)
+            vocab_size = self.num_embeddings
+            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
+            # Create a mask of valid vocab ids (1 means it needs to be masked).
+            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
+            input = input - vocab_start_index
+            input[input_ids_mask] = 0
+            embeddings = super().forward(input)
+            embeddings[input_ids_mask] = 0.0
+            return embeddings
+
+
+class ColumnParallelEmbedding(nn.Embedding):
+
+    def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
+        self.process_group = process_group
+        if process_group is not None:
+            world_size = torch.distributed.get_world_size(process_group)
+            if embedding_dim % world_size != 0:
+                raise ValueError(f'embedding_dim ({embedding_dim}) must be divisible by '
+                                 f'world_size ({world_size})')
+        else:
+            world_size = 1
+        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
+
+
 class ParallelGPT2Embeddings(nn.Module):

     def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
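The new VocabParallelEmbedding shards the vocabulary across the ranks of a process group: each rank embeds only the token ids that fall inside its shard and writes zeros for the rest, so summing the per-rank outputs recovers the full lookup. A minimal single-process sketch of that masking logic, with rank and world_size assumed as plain constants instead of being read from torch.distributed:

import torch
import torch.nn as nn

vocab_size, embed_dim = 16, 4   # sizes assumed for illustration
world_size, rank = 2, 1         # assumed: 2-way tensor parallelism, viewed from rank 1

# Each rank owns a contiguous shard of the vocabulary (num_embeddings // world_size rows).
shard = nn.Embedding(vocab_size // world_size, embed_dim)
vocab_start, vocab_end = rank * shard.num_embeddings, (rank + 1) * shard.num_embeddings

input_ids = torch.randint(0, vocab_size, (2, 5))
# Mask token ids that fall outside this rank's shard (True means "masked").
mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
local_ids = (input_ids - vocab_start).masked_fill(mask, 0)
embeddings = shard(local_ids)
embeddings[mask] = 0.0          # out-of-shard tokens contribute zeros on this rank
print(embeddings.shape)         # torch.Size([2, 5, 4])

ColumnParallelEmbedding instead keeps the whole vocabulary on every rank but splits embedding_dim, so each rank produces only a slice of the hidden dimension.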
@@ -88,22 +134,17 @@ class ParallelGPT2Embeddings(nn.Module):
         """
             If max_position_embeddings <= 0, there's no position embeddings
         """
-        world_size = torch.distributed.get_world_size(process_group)
-        if vocab_size % world_size != 0:
-            raise ValueError(f'vocab_size ({vocab_size}) must be divisible by '
-                             f'world_size ({world_size})')
-        if embed_dim % world_size != 0:
-            raise ValueError(f'embed_dim ({embed_dim}) must be divisible by '
-                             f'world_size ({world_size})')
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         self.process_group = process_group
-        self.word_embeddings = nn.Embedding(vocab_size // world_size, embed_dim,
-                                            padding_idx=padding_idx, **factory_kwargs)
+        self.word_embeddings = VocabParallelEmbedding(
+            vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group,
+            **factory_kwargs)
         self.max_position_embeddings = max_position_embeddings
         if self.max_position_embeddings > 0:
-            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim // world_size,
-                                                    **factory_kwargs)
+            self.position_embeddings = ColumnParallelEmbedding(
+                max_position_embeddings, embed_dim, process_group=process_group,
+                **factory_kwargs)

     def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
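With process_group=None both new modules fall back to an ordinary nn.Embedding over the full vocabulary and the full embedding_dim, so the constructor above can delegate the shard sizing and divisibility checks to the sub-modules instead of dividing vocab_size and embed_dim itself. A small sketch of that non-distributed path, assuming a flash-attn build that includes this commit is importable:

import torch
from flash_attn.modules.embedding import VocabParallelEmbedding, ColumnParallelEmbedding

# process_group=None means world_size == 1, so nothing is actually sharded.
word_emb = VocabParallelEmbedding(32, 8, process_group=None)   # full vocab on this rank
pos_emb = ColumnParallelEmbedding(16, 8, process_group=None)   # full embedding_dim on this rank

tokens = torch.randint(0, 32, (2, 5))
positions = torch.arange(5)
hidden = word_emb(tokens) + pos_emb(positions)
print(hidden.shape)   # torch.Size([2, 5, 8])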
@@ -113,32 +154,17 @@ class ParallelGPT2Embeddings(nn.Module):
         """
         batch_size, seqlen = input_ids.shape
         world_size = torch.distributed.get_world_size(self.process_group)
-        if world_size <= 1:
-            embeddings = self.word_embeddings(input_ids)
-            if self.max_position_embeddings > 0:
-                if position_ids is None:
-                    position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
-                position_embeddings = self.position_embeddings(position_ids)
+        embeddings = self.word_embeddings(input_ids)
+        if self.max_position_embeddings > 0:
+            if position_ids is None:
+                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
+            position_embeddings = self.position_embeddings(position_ids)
+            if world_size <= 1:
                 embeddings = embeddings + position_embeddings
-            if combine_batch_seqlen_dim:
-                embeddings = rearrange(embeddings, 'b s d -> (b s) d')
-            return embeddings
-        else:
-            rank = torch.distributed.get_rank(self.process_group)
-            vocab_size = self.word_embeddings.num_embeddings
-            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
-            # Create a mask of valid vocab ids (1 means it needs to be masked).
-            input_ids_mask = (input_ids < vocab_start_index) | (input_ids >= vocab_end_index)
-            input_ids = input_ids - vocab_start_index
-            input_ids[input_ids_mask] = 0
-            embeddings = self.word_embeddings(input_ids)
-            embeddings[input_ids_mask] = 0.0
-            if self.max_position_embeddings > 0:
-                if position_ids is None:
-                    position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
-                position_embeddings = self.position_embeddings(position_ids)
+            else:
                 partition_dim = self.position_embeddings.embedding_dim
+                rank = torch.distributed.get_rank(self.process_group)
                 embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
-            if combine_batch_seqlen_dim:
-                embeddings = rearrange(embeddings, 'b s d -> (b s) d')
-            return reduce_scatter(embeddings, self.process_group)
+        if combine_batch_seqlen_dim:
+            embeddings = rearrange(embeddings, 'b s d -> (b s) d')
+        return embeddings if world_size <= 1 else reduce_scatter(embeddings, self.process_group)
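In the world_size > 1 path above, each rank now holds a vocabulary shard of the word embeddings (full hidden size, zeros for out-of-shard tokens) plus a column shard of the position embeddings, which it adds into its own slice of the hidden dimension; reduce_scatter then sums these partial results across ranks and splits the summed tensor so each rank keeps only its portion. A single-process numeric sketch of why that cross-rank sum reproduces word plus position embeddings (all sizes and the simulated world_size are assumed for illustration):

import torch

torch.manual_seed(0)
world_size, vocab_size, max_pos, embed_dim, seqlen = 2, 8, 6, 4, 6
word_table = torch.randn(vocab_size, embed_dim)   # full, unsharded lookup tables
pos_table = torch.randn(max_pos, embed_dim)

input_ids = torch.randint(0, vocab_size, (1, seqlen))
position_ids = torch.arange(seqlen)
reference = word_table[input_ids] + pos_table[position_ids]

shard_vocab = vocab_size // world_size      # VocabParallelEmbedding shard size
partition_dim = embed_dim // world_size     # ColumnParallelEmbedding shard size
partial_sum = torch.zeros(1, seqlen, embed_dim)
for rank in range(world_size):              # simulate each rank's partial result
    lo, hi = rank * shard_vocab, (rank + 1) * shard_vocab
    mask = (input_ids < lo) | (input_ids >= hi)
    local_ids = (input_ids - lo).masked_fill(mask, 0)
    emb = word_table[lo:hi][local_ids]      # this rank's word-embedding shard
    emb[mask] = 0.0                         # out-of-shard tokens contribute zeros
    # Each rank adds its slice of the position embeddings into its column block.
    pos_slice = pos_table[position_ids, rank * partition_dim:(rank + 1) * partition_dim]
    emb[..., rank * partition_dim:(rank + 1) * partition_dim] += pos_slice
    partial_sum += emb

assert torch.allclose(partial_sum, reference)   # the cross-rank sum is the full embedding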