OpenDAS / ColossalAI / Commits

Commit 507c0ad3
authored Jun 16, 2023 by FoolPlayer, committed by Frank Lee on Jul 04, 2023
add vocabembedding layer

parent 45d93843
Showing 2 changed files with 100 additions and 10 deletions (+100 -10)

colossalai/shardformer/layer/layers.py                                    +55 -10
tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py     +45 -0
colossalai/shardformer/layer/layers.py (view file @ 507c0ad3)
@@ -139,6 +139,7 @@ class Linear1D_Col(ParallelModule):
         with self.randomizer.fork_rng(enable_cpu=True):
             self.reset_parameters(weight_initializer, bias_initializer)
 
+    @staticmethod
     def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
                            **kwargs) -> ParallelModule:
         r"""
@@ -587,6 +588,8 @@ class VocabParallelEmbedding1D(ParallelLayer):
                  embedding_dim: int,
                  padding_idx: int = None,
                  dtype: torch.dtype = None,
+                 device: torch.device = None,
+                 process_group: ProcessGroup = None,
                  weight_initializer: Callable = init.normal_(),
                  *args,
                  **kwargs):
@@ -596,21 +599,63 @@ class VocabParallelEmbedding1D(ParallelLayer):
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
+        self.process_group = process_group
 
-        tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-        tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+        tensor_parallel_size = dist.get_world_size(group=process_group)
+        tensor_parallel_rank = dist.get_rank(group=process_group)
 
-        # self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
-        self.num_embeddings_per_partition = num_embeddings
+        self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
+        self.num_embeddings = self.num_embeddings_per_partition
         self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
 
         self.weight = Parameter(
-            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype))
+            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype))
+
+        # offset the seed with randomizer index and rank
+        seed = torch.random.initial_seed()
+        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
+
+        with self.randomizer.fork_rng(enable_cpu=True):
+            self.reset_parameters(weight_initializer)
+        # self.reset_parameters(weight_initializer)
+        # self._set_tensor_parallel_attributes()
+        # set_parallel_input(False)
+        # env.vocab_parallel = True
+
+    @staticmethod
+    def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
+                           **kwargs) -> ParallelModule:
+        r"""
+        Convert a native pytorch embedding module to a parallel module.
+        """
+        # get the origin attributes
+        num_embeddings = module.num_embeddings
+        embedding_dim = module.embedding_dim
+        padding_idx = module.padding_idx
+        device = module.weight.device
+
+        # ensure only one process group is used
+        if isinstance(process_group, (list, tuple)):
+            assert len(process_group) == 1, \
+                f'Expected only one process group, got {len(process_group)}.'
+            process_group = process_group[0]
+
+        # create the parallel module
+        vocab_embedding_1d = VocabParallelEmbedding1D(num_embeddings=num_embeddings,
+                                                      embedding_dim=embedding_dim,
+                                                      padding_idx=padding_idx,
+                                                      device=device,
+                                                      process_group=process_group,
+                                                      *args,
+                                                      **kwargs)
+
+        with torch.no_grad():
+            # shard and slice the weight along the vocabulary(num_embeddings) dimension
+            # the shape of the weight is (num_embeddings, embedding_dim)
+            shard_weight = shard_rowwise(module.weight.data, process_group)
+            vocab_embedding_1d.weight.data.copy_(shard_weight)
 
-        self.reset_parameters(weight_initializer)
-        self._set_tensor_parallel_attributes()
-        set_parallel_input(False)
-        env.vocab_parallel = True
+        return vocab_embedding_1d
 
     def _set_tensor_parallel_attributes(self):
         set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)
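For orientation, here is a minimal single-rank sketch of what the row-wise sharding used in from_native_module amounts to: each rank keeps only the slice of the full (num_embeddings, embedding_dim) table that matches its position in the process group. The function name shard_rowwise_sketch and the plain slicing are illustrative assumptions, not the shard_rowwise helper called in the diff (which may return a distributed tensor).

# Illustrative sketch only: approximates the effect of row-wise sharding for one rank.
import torch

def shard_rowwise_sketch(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # divide(num_embeddings, tensor_parallel_size) in the diff above
    num_embeddings_per_partition = weight.shape[0] // world_size
    vocab_start_index = rank * num_embeddings_per_partition
    vocab_end_index = vocab_start_index + num_embeddings_per_partition
    # each rank keeps only its rows of the vocabulary dimension
    return weight[vocab_start_index:vocab_end_index].clone()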
@@ -665,5 +710,5 @@ class VocabParallelEmbedding1D(ParallelLayer):
         # Mask the output embedding.
         output_parallel[input_mask, :] = 0.
         # Reduce across all the model parallel GPUs.
-        output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
+        output = reduce_input(output_parallel, self.process_group)
         return output
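The hunk above only shows the tail of forward. As a rough single-rank sketch of the lookup pattern such a vocab-parallel layer follows (assuming reduce_input sums the partial results across the process group; the actual forward in layers.py may differ in details):

# Illustrative sketch only, not the layer's forward as committed.
import torch
import torch.nn.functional as F

def masked_local_lookup(input_ids: torch.Tensor, local_weight: torch.Tensor,
                        vocab_start_index: int, vocab_end_index: int) -> torch.Tensor:
    # ids outside this rank's [vocab_start_index, vocab_end_index) range are masked
    input_mask = (input_ids < vocab_start_index) | (input_ids >= vocab_end_index)
    masked_input = input_ids.clone() - vocab_start_index
    masked_input[input_mask] = 0
    # local lookup on the sharded (num_embeddings_per_partition, embedding_dim) weight
    output_parallel = F.embedding(masked_input, local_weight)
    # Mask the output embedding (as in the hunk above); the real layer then
    # all-reduces: output = reduce_input(output_parallel, self.process_group)
    output_parallel[input_mask, :] = 0.
    return output_parallel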
tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py (new file, mode 100644, view file @ 507c0ad3)
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close

import colossalai
from colossalai.shardformer.layer.layers import VocabParallelEmbedding1D
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


def check_vocab_embedding_1d():
    embedding = nn.Embedding(128, 32).to('cuda')
    dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None)

    assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
    assert dist_embedding_1d.num_embeddings == 64
    assert dist_embedding_1d.embed_dim == 32

    # check embedding correctness
    x = torch.randint(0, 128, (4, 32)).to('cuda')
    org_out = embedding(x)
    dist_out = dist_embedding_1d(x)
    assert_close(org_out, dist_out)

    # check backward correctness
    org_out.sum().backward()
    dist_out.sum().backward()

    rank = dist.get_rank()
    target_grad = torch.chunk(embedding.weight.grad, 2, dim=0)[rank]
    assert_close(target_grad, dist_embedding_1d.weight.grad)


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    check_vocab_embedding_1d()


@rerun_if_address_is_in_use()
def test_vocab_embedding():
    spawn(run_dist, nprocs=2)


if __name__ == '__main__':
    test_vocab_embedding()
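A quick, launcher-free sanity check of the numbers the test asserts (an illustrative aside, not part of the committed test): with nprocs=2, the 128 x 32 embedding table splits into two 64 x 32 shards, which is also why the backward check compares against torch.chunk(embedding.weight.grad, 2, dim=0)[rank].

# Single-process arithmetic check of the expected shard shapes.
import torch

full_weight = torch.randn(128, 32)
shards = torch.chunk(full_weight, 2, dim=0)    # one (64, 32) slice per rank
assert all(s.shape == torch.Size([64, 32]) for s in shards)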