OpenDAS / ColossalAI · Commits · c2947dad

[inference] streaming Linear 1D Row inference (#1874)

Unverified commit c2947dad, authored Nov 10, 2022 by Jiarui Fang and committed via GitHub on Nov 10, 2022.
Parent: a1416812

Showing 4 changed files with 629 additions and 554 deletions (+629, -554):
    colossalai/nn/layer/parallel_1d/layers.py                +21   -5
    tests/test_fx/test_complete_workflow.py                   +10   -7
    tests/test_layers/test_1d/checks_1d/check_layer_1d.py     +549  -496
    tests/test_layers/test_1d/test_1d.py                      +49   -46
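The core of the change is that Linear1D_Row can now split its row-parallel weight into stream_chunk_num chunks along dim 0 and compute the output chunk by chunk at inference time, concatenating the partial results. A minimal single-process sketch of why that is equivalent (plain PyTorch, illustrative sizes only; it leaves out the tensor-parallel all-reduce that reduce_input performs in the real layer):

import torch
import torch.nn.functional as F

# Illustrative sizes, not taken from the commit.
x = torch.randn(4, 16)      # input
W = torch.randn(32, 16)     # full weight, out_features x in_features
b = torch.randn(32)

# Split the weight along dim 0 (the output features), as chunk_weight() does.
chunks = torch.chunk(W, 2, dim=0)

# One smaller matmul per chunk; in the real layer each partial result is
# all-reduced across the tensor-parallel group before being concatenated.
parts = [F.linear(x, w) for w in chunks]
y = torch.cat(parts, dim=-1) + b

# The chunked result matches the single full matmul.
assert torch.allclose(y, F.linear(x, W, b))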
colossalai/nn/layer/parallel_1d/layers.py  (view file @ c2947dad)

@@ -597,9 +597,12 @@ class Linear1D_Row(ParallelLayer):
                  parallel_input: bool = True,
                  skip_bias_add: bool = False,
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
+                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+                 stream_chunk_num: int = 1):
         super().__init__()
 
+        self.stream_chunk_num = stream_chunk_num
+
         # Keep input parameters
         self.in_features = in_features
         self.out_features = out_features

@@ -617,6 +620,9 @@ class Linear1D_Row(ParallelLayer):
         factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
 
         self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
+        if self.stream_chunk_num > 1:
+            # TODO() work for inference only
+            self.chunk_weight()
         if bias:
             self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
         else:

@@ -626,6 +632,9 @@ class Linear1D_Row(ParallelLayer):
         self._set_tensor_parallel_attributes()
         set_parallel_input(False)
 
+    def chunk_weight(self):
+        self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0)
+
     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
         fan_in, fan_out = self.in_features, self.out_features
         weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)

@@ -696,10 +705,17 @@ class Linear1D_Row(ParallelLayer):
                 input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
             input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
 
-        output_parallel = F.linear(input_, self.weight)
-        # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
-        output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
+        if self.stream_chunk_num > 1:
+            output_parallel_list = [None for i in range(self.stream_chunk_num)]
+            for i in range(self.stream_chunk_num):
+                output_parallel_list[i] = F.linear(input_, self.weight_list[i])
+                output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
+            output = torch.cat(output_parallel_list, dim=-1)
+        else:
+            print(input_.shape, self.weight.shape)
+            output_parallel = F.linear(input_, self.weight)
+            # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
+            output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
 
         if not self.skip_bias_add:
             if self.bias is not None:
                 output = output + self.bias
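In use, the streaming path is driven the way the new test further below drives it: build the row-parallel layer with stream_chunk_num > 1, rebuild the weight chunks with chunk_weight() if the weight tensor is replaced, and run the forward pass under torch.no_grad(). A condensed, hypothetical sketch (assumes 2 GPUs and 1D tensor parallelism; the sizes and the worker function are illustrative, not from the commit):

from functools import partial

import torch
import torch.multiprocessing as mp

import colossalai
from colossalai.nn import Linear1D_Row
from colossalai.utils import free_port, get_current_device


def worker(rank, world_size, port):
    colossalai.launch(config=dict(parallel=dict(tensor=dict(size=world_size, mode='1d'))),
                      rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # in_features (64) is sharded across the tensor-parallel ranks; the local weight
    # rows are additionally split into stream_chunk_num chunks for inference.
    layer = Linear1D_Row(64, 32, stream_chunk_num=2)
    # Each rank feeds its shard of the input: last dim = 64 // world_size.
    x_local = torch.randn(4, 64 // world_size, device=get_current_device())
    with torch.no_grad():    # the chunked path is intended for inference only
        y = layer(x_local)
    assert y.shape == (4, 32)


if __name__ == '__main__':
    world_size = 2
    mp.spawn(partial(worker, world_size=world_size, port=free_port()), nprocs=world_size)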
tests/test_fx/test_complete_workflow.py  (view file @ c2947dad)

@@ -32,7 +32,7 @@ class MLP(torch.nn.Module):
         return x
 
 
-def run_workflow(world_size):
+def run_workflow(world_size, dev):
     # initailization
     with LazyInitContext() as ctx:
         model = MLP(16)

@@ -46,7 +46,7 @@ def run_workflow(world_size):
     gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
 
     # annotate
-    annotated_gm = transformer_mlp_pass(gm, process_group=ProcessGroup())
+    annotated_gm = transformer_mlp_pass(gm, process_group=ProcessGroup(tp_degree=world_size))
     annotated_gm.recompile()
 
     # materialization and sharding

@@ -61,22 +61,25 @@ def run_workflow(world_size):
     # test forward to make sure that IR transform will produce the same results
     # like how ColoTensor would do it normally
-    data = torch.rand(4, 16)
+    data = torch.rand(4, 16, device=dev)
     non_fx_out = model(data)
     fx_out = annotated_gm(data)
     assert torch.equal(non_fx_out, fx_out), f'{non_fx_out} vs {fx_out}'
 
 
-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, dev, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_workflow(world_size)
+    run_workflow(world_size, dev)
 
 
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2])
+@pytest.mark.parametrize('dev', ['cuda', 'cpu'])
 @rerun_if_address_is_in_use()
-def test_complete_workflow(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
+def test_complete_workflow(world_size, dev):
+    if dev == 'cpu' and world_size > 1:
+        return
+    run_func = partial(run_dist, world_size=world_size, dev=dev, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
tests/test_layers/test_1d/checks_1d/check_layer_1d.py  (view file @ c2947dad)

The commit rewrites this file wholesale, so the page renders the old and the new revision in full. Apart from the regrouped import block, the only functional change is the new check_linear_row_stream_inference() check at the end. The new revision:

import torch
import torch.distributed as dist
from torch.nn import Parameter

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import (
    Classifier1D,
    Embedding1D,
    Linear1D_Col,
    Linear1D_Row,
    VanillaClassifier,
    VocabParallelClassifier1D,
    VocabParallelCrossEntropyLoss1D,
    VocabParallelEmbedding1D,
)
from colossalai.utils import get_current_device, print_rank_0

from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal


def check_linear_col():
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    OUTPUT_SIZE = 2 * HIDDEN_SIZE

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    layer = Linear1D_Col(INPUT_SIZE, OUTPUT_SIZE)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    dist.broadcast(A_master, src=0)
    A = A_master.clone()
    A.requires_grad = True

    W_shape = (OUTPUT_SIZE, INPUT_SIZE)
    W_master = torch.randn(W_shape, dtype=dtype, device=device)
    dist.broadcast(W_master, src=0)
    W = torch.chunk(W_master, DEPTH, dim=0)[i]
    W = W.clone()
    W.requires_grad = True

    B_shape = (OUTPUT_SIZE)
    B_master = torch.randn(B_shape, dtype=dtype, device=device)
    dist.broadcast(B_master, src=0)
    B = torch.chunk(B_master, DEPTH, dim=0)[i]
    B = B.clone()
    B.requires_grad = True

    layer.weight = Parameter(W)
    layer.bias = Parameter(B)
    out = layer(A)

    A_master = A_master.clone()
    A_master.requires_grad = True
    W_master = W_master.clone()
    W_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
    C = torch.chunk(C_master, DEPTH, dim=-1)[i]

    check_equal(out, C)
    print_rank_0('linear_col forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
    dist.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
    grad = grad.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)
    A_grad = A_master.grad
    check_equal(A_grad, A.grad)

    W_grad = W_master.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
    check_equal(W_grad, layer.weight.grad)

    B_grad = B_master.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
    check_equal(B_grad, layer.bias.grad)

    print_rank_0('linear_col backward: pass')


def check_linear_row():
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    OUTPUT_SIZE = 2 * HIDDEN_SIZE

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    dist.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=-1)[i]
    A = A.clone()
    A.requires_grad = True

    W_shape = (INPUT_SIZE, OUTPUT_SIZE)
    W_master = torch.randn(W_shape, dtype=dtype, device=device)
    dist.broadcast(W_master, src=0)
    W = torch.chunk(W_master, DEPTH, dim=-1)[i]
    W = W.clone()
    W.requires_grad = True

    B_shape = (INPUT_SIZE)
    B_master = torch.randn(B_shape, dtype=dtype, device=device)
    dist.broadcast(B_master, src=0)
    B = B_master.clone()
    B.requires_grad = True

    layer.weight = Parameter(W)
    layer.bias = Parameter(B)
    out = layer(A)

    A_master = A_master.clone()
    A_master.requires_grad = True
    W_master = W_master.clone()
    W_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
    C = C_master.clone()

    check_equal(out, C)
    print_rank_0('linear_row forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
    dist.broadcast(grad_master, src=0)
    grad = grad_master.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i]
    check_equal(A_grad, A.grad)

    W_grad = W_master.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
    check_equal(W_grad, layer.weight.grad)

    B_grad = B_master.grad
    check_equal(B_grad, layer.bias.grad)

    print_rank_0('linear_row backward: pass')


def check_embed():
    device = get_current_device()
    dtype = torch.float32
    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE)
    embed = embed.to(dtype).to(device)
    embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
    embed_master = embed_master.to(dtype).to(device)

    weight_master = embed_master.weight.data
    torch.distributed.broadcast(weight_master, src=0)
    weight = torch.chunk(weight_master, DEPTH, dim=-1)[i]
    embed.weight.data.copy_(weight)

    A_shape = (BATCH_SIZE, SEQ_LENGTH)
    A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    out = embed(A)

    A_master = A_master.clone()
    C_master = embed_master(A_master)
    C = C_master.clone()
    check_equal(out, C)
    print_rank_0('embed forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = grad_master.clone()
    out.backward(grad)
    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    B_grad = embed_master.weight.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
    check_equal(B_grad, embed.weight.grad)
    print_rank_0('embed backward: pass')


def check_vocab_parallel_embed():
    device = get_current_device()
    dtype = torch.float32
    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE)
    embed = embed.to(dtype).to(device)
    embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
    embed_master = embed_master.to(dtype).to(device)

    weight_master = embed_master.weight.data
    torch.distributed.broadcast(weight_master, src=0)
    weight = torch.chunk(weight_master, DEPTH, dim=0)[i]
    embed.weight.data.copy_(weight)

    A_shape = (BATCH_SIZE, SEQ_LENGTH)
    A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    out = embed(A)

    A_master = A_master.clone()
    C_master = embed_master(A_master)
    C = C_master.clone()
    check_equal(out, C)
    print_rank_0('vocab parallel embed forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = grad_master.clone()
    out.backward(grad)
    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    B_grad = embed_master.weight.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
    check_equal(B_grad, embed.weight.grad)
    print_rank_0('vocab parallel embed backward: pass')


def check_classifier_no_given_weight():
    device = get_current_device()
    dtype = torch.float32
    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    env.parallel_input_1d = False
    parallel_input_1d = env.parallel_input_1d
    layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, bias=True)
    layer.to(dtype).to(device)

    layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, bias=True)
    layer_master = layer_master.to(dtype).to(device)

    W_master = layer_master.weight.data
    dist.broadcast(W_master, src=0)
    W = torch.chunk(W_master, DEPTH, dim=-1)[i]
    layer.weight.data.copy_(W)
    B_master = layer_master.bias.data
    dist.broadcast(B_master, src=0)
    B = B_master.clone()
    layer.bias.data.copy_(B)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    dist.broadcast(A_master, src=0)
    if parallel_input_1d:
        A = torch.chunk(A_master, DEPTH, dim=-1)[i]
        A = A.clone()
    else:
        A = A_master.clone()
    A.requires_grad = True

    out = layer(A)

    A_master = A_master.clone()
    A_master.requires_grad = True
    C_master = layer_master(A_master)
    C = C_master.clone()

    check_equal(out, C)
    print_rank_0('classifier (no given weight) forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    dist.broadcast(grad_master, src=0)
    grad = grad_master.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    A_grad = A_master.grad
    if parallel_input_1d:
        A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i]
    check_equal(A_grad, A.grad)

    W_grad = layer_master.weight.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
    check_equal(W_grad, layer.weight.grad)

    B_grad = layer_master.bias.grad
    check_equal(B_grad, layer.bias.grad)

    print_rank_0('classifier (no given weight) backward: pass')


def check_vocab_parallel_classifier_no_given_weight():
    device = get_current_device()
    dtype = torch.float32
    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    layer = VocabParallelClassifier1D(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
    layer.to(dtype).to(device)

    layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
    layer_master = layer_master.to(dtype).to(device)

    W_master = layer_master.weight.data
    dist.broadcast(W_master, src=0)
    W = torch.chunk(W_master, DEPTH, dim=0)[i]
    layer.weight.data.copy_(W)
    B_master = layer_master.bias.data
    dist.broadcast(B_master, src=0)
    B = torch.chunk(B_master, DEPTH, dim=0)[i]
    layer.bias.data.copy_(B)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    dist.broadcast(A_master, src=0)
    A = A_master.clone()
    A.requires_grad = True

    out = layer(A)

    A_master = A_master.clone()
    A_master.requires_grad = True
    C_master = layer_master(A_master)
    C = torch.chunk(C_master, DEPTH, dim=-1)[i]

    check_equal(out, C)
    print_rank_0('vocab parallel classifier (no given weight) forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    dist.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
    grad = grad.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    A_grad = A_master.grad
    check_equal(A_grad, A.grad)

    W_grad = layer_master.weight.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
    check_equal(W_grad, layer.weight.grad)

    B_grad = layer_master.bias.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
    check_equal(B_grad, layer.bias.grad)

    print_rank_0('vocab parallel classifier (no given weight) backward: pass')


def check_classifier_given_embed_weight():
    device = get_current_device()
    dtype = torch.float32
    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE)
    embed = embed.to(dtype).to(device)
    embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
    embed_master = embed_master.to(dtype).to(device)

    weight_master = embed_master.weight.data
    torch.distributed.broadcast(weight_master, src=0)
    weight = torch.chunk(weight_master, DEPTH, dim=-1)[i]
    embed.weight.data.copy_(weight)

    env.parallel_input_1d = False
    layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False)
    layer.to(dtype).to(device)

    layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False)
    layer_master = layer_master.to(dtype).to(device)

    A_shape = (BATCH_SIZE, SEQ_LENGTH)
    A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    out = layer(embed(A))

    A_master = A_master.clone()
    C_master = layer_master(embed_master(A_master))
    C = C_master.clone()
    check_equal(out, C)
    print_rank_0('classifier (given embed weight) forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    dist.broadcast(grad_master, src=0)
    grad = grad_master.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    W_grad = embed_master.weight.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
    check_equal(W_grad, embed.weight.grad)

    print_rank_0('classifier (given embed weight) backward: pass')


def check_vocab_parallel_classifier_given_embed_weight():
    device = get_current_device()
    dtype = torch.float32
    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE)
    embed = embed.to(dtype).to(device)
    embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
    embed_master = embed_master.to(dtype).to(device)

    weight_master = embed_master.weight.data
    torch.distributed.broadcast(weight_master, src=0)
    weight = torch.chunk(weight_master, DEPTH, dim=0)[i]
    embed.weight.data.copy_(weight)

    env.parallel_input_1d = False
    layer = VocabParallelClassifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False)
    layer.to(dtype).to(device)

    layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False)
    layer_master = layer_master.to(dtype).to(device)

    A_shape = (BATCH_SIZE, SEQ_LENGTH)
    A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    out = layer(embed(A))

    A_master = A_master.clone()
    C_master = layer_master(embed_master(A_master))
    C = torch.chunk(C_master, DEPTH, dim=-1)[i]
    check_equal(out, C)
    print_rank_0('vocab parallel classifier (given embed weight) forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    dist.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
    grad = grad.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    W_grad = embed_master.weight.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
    check_equal(W_grad, embed.weight.grad)

    print_rank_0('vocab parallel classifier (given embed weight) backward: pass')


def check_vocab_parallel_loss():
    device = get_current_device()
    dtype = torch.float32
    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    criterion = VocabParallelCrossEntropyLoss1D()
    criterion_master = torch.nn.CrossEntropyLoss()

    out_shape = (BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES)
    out_master = torch.randn(out_shape, dtype=dtype, device=device)
    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, SEQ_LENGTH), dtype=torch.long, device=device)
    torch.distributed.broadcast(out_master, src=0)
    torch.distributed.broadcast(target_master, src=0)

    out = torch.chunk(out_master, DEPTH, dim=-1)[i]
    out = out.clone()
    out.requires_grad = True

    loss = criterion(out, target_master)

    out_master = out_master.clone()
    out_master.requires_grad = True
    loss_master = criterion_master(out_master, target_master)
    check_equal(loss, loss_master)
    print_rank_0('vocab parallel loss forward: pass')

    loss.backward()
    loss_master.backward()

    out_grad = out_master.grad
    out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[i]
    check_equal(out_grad, out.grad)
    print_rank_0('vocab parallel loss backward: pass')


@torch.no_grad()
def check_linear_row_stream_inference():
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    OUTPUT_SIZE = 2 * HIDDEN_SIZE

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
    assert HIDDEN_SIZE % 2 == 0
    layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, stream_chunk_num=2)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    dist.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=-1)[i]
    A = A.clone()

    W_shape = (INPUT_SIZE, OUTPUT_SIZE)
    W_master = torch.randn(W_shape, dtype=dtype, device=device)
    dist.broadcast(W_master, src=0)
    W = torch.chunk(W_master, DEPTH, dim=-1)[i]
    W = W.clone()

    B_shape = (INPUT_SIZE)
    B_master = torch.randn(B_shape, dtype=dtype, device=device)
    dist.broadcast(B_master, src=0)
    B = B_master.clone()

    layer.weight = Parameter(W)
    layer.bias = Parameter(B)
    layer.chunk_weight()
    out = layer(A)

    A_master = A_master.clone()
    W_master = W_master.clone()
    B_master = B_master.clone()
    C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
    C = C_master.clone()

    check_equal(out, C)
    print_rank_0('linear_row forward: pass')
tests/test_layers/test_1d/test_1d.py  (view file @ c2947dad)

Both revisions are rendered on the page; the new one only regroups the imports and adds the check_linear_row_stream_inference() call to check_layer(). The new revision:

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
from checks_1d.check_layer_1d import *

from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port

CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),)


def check_layer(rank, world_size, port):
    disable_existing_loggers()
    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    check_linear_col()
    check_linear_row()
    check_embed()
    check_vocab_parallel_embed()
    check_classifier_no_given_weight()
    check_vocab_parallel_classifier_no_given_weight()
    check_classifier_given_embed_weight()
    check_vocab_parallel_classifier_given_embed_weight()
    check_vocab_parallel_loss()

    check_linear_row_stream_inference()

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_1d():
    world_size = 4
    run_func = partial(check_layer, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_1d()