OpenDAS / ColossalAI / Commits / f593a563

Unverified commit f593a563, authored Apr 29, 2022 by Ziyue Jiang, committed by GitHub on Apr 29, 2022.
[Tensor] add embedding tp1d row (#904)
Parent: 16122d5f
Showing 5 changed files with 108 additions and 8 deletions (+108 −8).
colossalai/tensor/_ops/embedding.py     +35 −2
colossalai/tensor/colo_tensor.py        +1 −0
tests/components_to_test/simple_net.py  +5 −3
tests/test_tensor/test_embedding_tp.py  +50 −0
tests/test_tensor/test_model.py         +17 −3
colossalai/tensor/_ops/embedding.py
```diff
@@ -9,7 +9,7 @@ from packaging import version
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern

 def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
-    # embedding_1Dcol split the weight(lookup table)
+    # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
     parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol_Embedding)
     if not input_tensor.is_gathered():
```
```diff
@@ -25,6 +25,37 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
         output.gather()
     return output

+def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
+    # embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
+    # Find index in this shard and mask those not here
+    # Reduce all
+    parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow_Embedding)
+    if not input_tensor.is_gathered():
+        input_tensor.gather()
+
+    tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode)
+    num_embeddings_per_partition = weight.size(0)
+    vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
+    vocab_end_index = vocab_start_index + num_embeddings_per_partition
+
+    # Build the mask.
+    input_mask = (input_tensor.torch_tensor() < vocab_start_index) | \
+                 (input_tensor.torch_tensor() >= vocab_end_index)
+    # Mask the input.
+    # TODO(jzy) masked_input may be an activation managed by ColoTensor.
+    masked_input = input_tensor.torch_tensor().clone() - vocab_start_index
+    masked_input[input_mask] = 0
+
+    partial_output = torch.nn.functional.embedding(masked_input, weight.torch_tensor(), *args, **kwargs)
+
+    # Mask the output embedding.
+    partial_output[input_mask, :] = 0.
+    # Reduce across all the model parallel GPUs.
+    output = reduce_input(partial_output, parallel_action.parallel_mode)
+    output = ColoTensor.init_from_torch_tensor(output)
+    return output
+
 @colo_op_impl(torch.nn.functional.embedding)
 def colo_embedding(types, args, kwargs, pg):
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
```
```diff
@@ -48,7 +79,9 @@ def colo_embedding(types, args, kwargs, pg):
         return ColoTensor.init_from_torch_tensor(output)
     elif weight.shard_spec.num_action == 1:  # Single Model Parallel Applied
         compute_patterns = weight.shard_spec.compute_patterns
-        if ComputePattern.TP1DCol_Embedding in compute_patterns:
+        if ComputePattern.TP1DRow_Embedding in compute_patterns:
+            return colo_embedding_1Drow(input_tensor, weight, args, kwargs)
+        elif ComputePattern.TP1DCol_Embedding in compute_patterns:
             return colo_embedding_1Dcol(input_tensor, weight, args, kwargs)
         else:
             raise NotImplementedError
```
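For intuition, here is a minimal single-process sketch of the mask-and-reduce scheme that `colo_embedding_1Drow` implements: the lookup table is chunked along `num_embeddings`, each simulated rank zeroes out indices it does not own, and a plain sum stands in for the cross-rank all-reduce that `reduce_input` performs. Names like `P`, `idx`, and `per_part` are illustrative, not part of the library.

```python
import torch
import torch.nn.functional as F

P = 4                                 # simulated tensor-parallel world size
num_embeddings, embedding_dim = 12, 8
weight = torch.randn(num_embeddings, embedding_dim)
idx = torch.tensor([0, 3, 6, 9])      # token ids to look up

per_part = num_embeddings // P
output = torch.zeros(idx.size(0), embedding_dim)
for rank, shard in enumerate(torch.chunk(weight, P, dim=0)):
    vocab_start = rank * per_part
    vocab_end = vocab_start + per_part
    # Mask indices owned by other ranks, shift the rest into local range.
    mask = (idx < vocab_start) | (idx >= vocab_end)
    local_idx = idx.clone() - vocab_start
    local_idx[mask] = 0
    partial = F.embedding(local_idx, shard)
    partial[mask, :] = 0.             # zero rows looked up with the dummy index
    output += partial                 # stands in for the cross-rank all-reduce

assert torch.allclose(output, F.embedding(idx, weight))
```

Each index falls inside exactly one rank's vocab range, so summing the masked partial outputs reproduces the unsharded lookup exactly.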
colossalai/tensor/colo_tensor.py
```diff
@@ -166,6 +166,7 @@ class ColoTensor(object):
             dim = -1
         self._torch_tensor = gather_forward_split_backward(self._torch_tensor, parallel_action.parallel_mode, dim=dim)
         self._shard_pattern = ShardPattern.NA
+        self._size = self._torch_tensor.size()

     def is_gathered(self) -> bool:
         return self._shard_pattern == ShardPattern.NA
```
tests/components_to_test/simple_net.py
```diff
@@ -5,7 +5,6 @@ from .utils.dummy_data_generator import DummyDataGenerator
 from .registry import non_distributed_component_funcs
 from colossalai.utils.cuda import get_current_device

 class SimpleNet(CheckpointModule):
     """
     In this no-leaf module, it has subordinate nn.modules and a nn.Parameter.
```
```diff
@@ -13,12 +12,14 @@ class SimpleNet(CheckpointModule):
     def __init__(self, checkpoint=False) -> None:
         super().__init__(checkpoint=checkpoint)
+        self.embed = nn.Embedding(20, 4)
         self.proj1 = nn.Linear(4, 8)
         self.ln1 = nn.LayerNorm(8)
         self.proj2 = nn.Linear(8, 4)
         self.ln2 = nn.LayerNorm(4)

     def forward(self, x):
+        x = self.embed(x)
         x = self.proj1(x)
         x = self.ln1(x)
         x = self.proj2(x)
```
...
@@ -26,11 +27,12 @@ class SimpleNet(CheckpointModule):
return
x
class
DummyDataLoader
(
DummyDataGenerator
):
def
generate
(
self
):
data
=
torch
.
rand
(
16
,
4
,
device
=
get_current_device
())
label
=
torch
.
randint
(
low
=
0
,
high
=
2
,
size
=
(
16
,),
device
=
get_current_device
())
data
=
torch
.
rand
int
(
low
=
0
,
high
=
20
,
size
=
(
16
,
20
)
,
device
=
get_current_device
())
label
=
torch
.
randint
(
low
=
0
,
high
=
2
,
size
=
(
16
,
4
),
device
=
get_current_device
())
return
data
,
label
...
...
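The DummyDataLoader change follows from the new first layer: SimpleNet now starts with `nn.Embedding(20, 4)`, so inputs must be integer token ids in [0, 20) rather than float features. A quick standalone shape check, reusing the sizes from the diff above:

```python
import torch
import torch.nn as nn

# Token ids in [0, 20) feed nn.Embedding(20, 4), as in the new SimpleNet.
embed = nn.Embedding(20, 4)
data = torch.randint(low=0, high=20, size=(16, 20))
print(embed(data).shape)  # torch.Size([16, 20, 4]) -> proj1 expects 4 features
```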
tests/test_tensor/test_embedding_tp.py
```diff
@@ -65,10 +65,60 @@ def run_embedding_tp1d_col_test():
     W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[local_rank]
     check_equal(W_grad, layer.weight.grad)

+def run_embedding_tp1d_row_test():
+    device = get_current_device()
+    dtype = torch.float32
+    DEPTH = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+    num_embeddings = 12
+    embedding_dim = 32
+    local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+
+    layer_master = torch.nn.Embedding(num_embeddings, embedding_dim)
+    layer = torch.nn.Embedding(num_embeddings, embedding_dim)
+
+    A_master = torch.tensor((0, 3, 6, 9), device=device)
+    A = broadcast_tensor_chunk(A_master, chunk_size=1)
+
+    W_shape = (num_embeddings, embedding_dim)
+    W_master = torch.randn(W_shape, dtype=dtype, device=device)
+    W = broadcast_tensor_chunk(W_master, chunk_size=1)
+    W.requires_grad = True
+
+    # replace the torch nn.Parameters with ColoTensor
+    sharded_weight = ColoTensor.init_from_torch_tensor(W)
+    parallel_action_list = [
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
+    ]
+    spec = TensorSpec(parallel_action_list)
+    sharded_weight.set_spec(spec)  # reshard
+    replace_parameter_add_grad(layer, sharded_weight)
+    out = layer(A)
+
+    replace_parameter_add_grad(layer_master, W_master)
+    C_master = layer_master(A_master)
+    C = C_master.clone()
+
+    check_equal(out, C)
+
+    grad_shape = C_master.shape
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad = broadcast_tensor_chunk(grad_master, chunk_size=1)
+    out.backward(grad)
+
+    grad_master = grad_master.clone()
+    C_master.backward(grad_master)
+
+    W_grad = W_master.grad
+    W_grad = torch.chunk(W_grad, DEPTH, dim=0)[local_rank]
+    check_equal(W_grad, layer.weight.grad)
+
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_embedding_tp1d_col_test()
+    run_embedding_tp1d_row_test()

 @pytest.mark.dist
 @parameterize('world_size', [1, 4])
```
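Note how the gradient check mirrors the sharding axis: the column test chunks `W_grad` along `dim=-1`, while the new row test chunks along `dim=0`, since a row-sharded table gives each rank a contiguous slice of the vocabulary. A small standalone sanity check of that bookkeeping, with the shapes used above:

```python
import torch

DEPTH, num_embeddings, embedding_dim = 4, 12, 32
W_grad = torch.randn(num_embeddings, embedding_dim)

# TP1DRow_Embedding: each rank owns num_embeddings / DEPTH vocabulary rows.
assert torch.chunk(W_grad, DEPTH, dim=0)[1].shape == (3, 32)

# TP1DCol_Embedding: each rank owns embedding_dim / DEPTH feature columns.
assert torch.chunk(W_grad, DEPTH, dim=-1)[1].shape == (12, 8)
```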
tests/test_tensor/test_model.py
```diff
@@ -47,6 +47,11 @@ def run_1d_col_tp():
     ]
     spec_col = TensorSpec(parallel_action_list_col)

+    parallel_action_list_embedding_col = [
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
+    ]
+    spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
+
     set_seed(1)
     if rank == 0:
         model_torch = model_builder(checkpoint=True)
```
```diff
@@ -60,6 +65,8 @@ def run_1d_col_tp():
                 p.set_spec(spec_col)
             if 'proj2' in name and 'weight' in name:
                 p.set_spec(spec_row)
+            if 'embed' in name and 'weight' in name:
+                p.set_spec(spec_embedding_col)
     model = model.cuda()
```
```diff
@@ -172,6 +179,11 @@ def run_1d_row_tp():
     ]
     spec = TensorSpec(parallel_action_list)

+    parallel_action_list_embedding_row = [
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
+    ]
+    spec_embedding_row = TensorSpec(parallel_action_list_embedding_row)
+
     set_seed(1)
     if rank == 0:
         model_torch = model_builder(checkpoint=True)
```
```diff
@@ -183,6 +195,8 @@ def run_1d_row_tp():
                 continue
             if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
                 p.set_spec(spec)
+            if 'embed' in name and 'weight' in name:
+                p.set_spec(spec_embedding_row)
     model = model.cuda()
```
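Both test paths attach specs by parameter name: walk `named_parameters()` and match on substrings, with the embedding weight routed to its own spec. A toy sketch of that selection logic (the string specs are placeholders, not the library API):

```python
# Illustrative only: mimics the name matching in run_1d_row_tp above.
def pick_spec(name: str):
    if 'embed' in name and 'weight' in name:
        return 'spec_embedding_row'
    if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
        return 'spec'
    return None

assert pick_spec('embed.weight') == 'spec_embedding_row'
assert pick_spec('proj1.weight') == 'spec'
assert pick_spec('ln1.weight') is None
```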
```diff
@@ -227,7 +241,7 @@ def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_1d_row_tp()
     run_1d_col_tp()

 @pytest.mark.dist
 @parameterize('world_size', [1, 4])
```
```diff
@@ -238,6 +252,6 @@ def test_simple_net(world_size):
 if __name__ == '__main__':
-    # test_simple_net()
+    test_simple_net()
     # test_model_parameters()
-    test_colo_optimizer()
+    # test_colo_optimizer()
```