Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
0b8161fa
Commit
0b8161fa
authored
Oct 26, 2022
by
kurisusnowdeng
Committed by
アマデウス
Nov 02, 2022
Browse files
updated tp layers
parent
cb5a587e
Changes
13
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
643 additions
and
291 deletions
+643
-291
colossalai/constants.py
colossalai/constants.py
+2
-0
colossalai/context/parallel_mode.py
colossalai/context/parallel_mode.py
+2
-0
colossalai/context/process_group_initializer/initializer_3d.py
...salai/context/process_group_initializer/initializer_3d.py
+111
-1
colossalai/global_variables.py
colossalai/global_variables.py
+8
-2
colossalai/nn/layer/parallel_1d/_operation.py
colossalai/nn/layer/parallel_1d/_operation.py
+51
-0
colossalai/nn/layer/parallel_1d/layers.py
colossalai/nn/layer/parallel_1d/layers.py
+24
-5
colossalai/nn/layer/parallel_3d/_operation.py
colossalai/nn/layer/parallel_3d/_operation.py
+223
-150
colossalai/nn/layer/parallel_3d/_utils.py
colossalai/nn/layer/parallel_3d/_utils.py
+69
-20
colossalai/nn/layer/parallel_3d/layers.py
colossalai/nn/layer/parallel_3d/layers.py
+110
-59
docker/Dockerfile
docker/Dockerfile
+3
-3
tests/test_layers/test_3d/checks_3d/check_layer_3d.py
tests/test_layers/test_3d/checks_3d/check_layer_3d.py
+35
-44
tests/test_layers/test_3d/checks_3d/common.py
tests/test_layers/test_3d/checks_3d/common.py
+3
-3
tests/test_layers/test_3d/test_3d.py
tests/test_layers/test_3d/test_3d.py
+2
-4
No files found.
colossalai/constants.py
View file @
0b8161fa
...
...
@@ -23,6 +23,8 @@ INITIALIZER_MAPPING = {
INPUT_GROUP_3D
=
'input_group_3d'
WEIGHT_GROUP_3D
=
'weight_group_3d'
OUTPUT_GROUP_3D
=
'output_group_3d'
INPUT_X_WEIGHT_3D
=
'input_x_weight_group_3d'
OUTPUT_X_WEIGHT_3D
=
'output_x_weight_group_3d'
# Attributes of tensor parallel parameters
IS_TENSOR_PARALLEL
=
'is_tensor_parallel'
...
...
colossalai/context/parallel_mode.py
View file @
0b8161fa
...
...
@@ -39,6 +39,8 @@ class ParallelMode(Enum):
PARALLEL_3D_INPUT
=
'3d_input'
PARALLEL_3D_WEIGHT
=
'3d_weight'
PARALLEL_3D_OUTPUT
=
'3d_output'
PARALLEL_3D_INPUT_X_WEIGHT
=
"3d_input_x_weight"
PARALLEL_3D_OUTPUT_X_WEIGHT
=
"3d_output_x_weight"
# 2.5D parallel
PARALLEL_2P5D_ROW
=
'2p5d_row'
...
...
colossalai/context/process_group_initializer/initializer_3d.py
View file @
0b8161fa
...
...
@@ -176,6 +176,112 @@ class Initializer_3D_Output(ProcessGroupInitializer):
return
local_rank
,
group_world_size
,
process_group
,
cpu_group
,
ranks_in_group
,
mode
class
Initializer_3D_InputxWeight
(
ProcessGroupInitializer
):
"""3D tensor parallel initialization among input.
Args:
num_group (int): The number of all tensor groups.
depth (int): Depth of 3D parallelism.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def
__init__
(
self
,
num_group
:
int
,
depth
:
int
,
*
args
):
super
().
__init__
(
*
args
)
self
.
num_group
=
num_group
self
.
depth
=
depth
def
init_dist_group
(
self
):
"""Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
3D tensor parallelism's information among input in a tuple.
"""
local_rank
=
None
ranks_in_group
=
None
process_group
=
None
cpu_group
=
None
group_world_size
=
None
mode
=
ParallelMode
.
PARALLEL_3D_INPUT_X_WEIGHT
env
.
input_x_weight_group_3d
=
mode
for
h
in
range
(
self
.
num_group
):
for
k
in
range
(
self
.
depth
):
ranks
=
[
h
*
self
.
depth
**
3
+
i
+
self
.
depth
*
(
j
+
self
.
depth
*
k
)
for
j
in
range
(
self
.
depth
)
for
i
in
range
(
self
.
depth
)
]
group
=
dist
.
new_group
(
ranks
)
group_cpu
=
dist
.
new_group
(
ranks
,
backend
=
'gloo'
)
if
dist
.
get_backend
()
!=
'gloo'
else
group
if
self
.
rank
in
ranks
:
local_rank
=
ranks
.
index
(
self
.
rank
)
group_world_size
=
len
(
ranks
)
process_group
=
group
cpu_group
=
group_cpu
ranks_in_group
=
ranks
return
local_rank
,
group_world_size
,
process_group
,
cpu_group
,
ranks_in_group
,
mode
class
Initializer_3D_OutputxWeight
(
ProcessGroupInitializer
):
"""3D tensor parallel initialization among input.
Args:
num_group (int): The number of all tensor groups.
depth (int): Depth of 3D parallelism.
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def
__init__
(
self
,
num_group
:
int
,
depth
:
int
,
*
args
):
super
().
__init__
(
*
args
)
self
.
num_group
=
num_group
self
.
depth
=
depth
def
init_dist_group
(
self
):
"""Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
3D tensor parallelism's information among input in a tuple.
"""
local_rank
=
None
ranks_in_group
=
None
process_group
=
None
cpu_group
=
None
group_world_size
=
None
mode
=
ParallelMode
.
PARALLEL_3D_OUTPUT_X_WEIGHT
env
.
output_x_weight_group_3d
=
mode
for
h
in
range
(
self
.
num_group
):
for
j
in
range
(
self
.
depth
):
ranks
=
[
h
*
self
.
depth
**
3
+
i
+
self
.
depth
*
(
j
+
self
.
depth
*
k
)
for
k
in
range
(
self
.
depth
)
for
i
in
range
(
self
.
depth
)
]
group
=
dist
.
new_group
(
ranks
)
group_cpu
=
dist
.
new_group
(
ranks
,
backend
=
'gloo'
)
if
dist
.
get_backend
()
!=
'gloo'
else
group
if
self
.
rank
in
ranks
:
local_rank
=
ranks
.
index
(
self
.
rank
)
group_world_size
=
len
(
ranks
)
process_group
=
group
cpu_group
=
group_cpu
ranks_in_group
=
ranks
return
local_rank
,
group_world_size
,
process_group
,
cpu_group
,
ranks_in_group
,
mode
@
DIST_GROUP_INITIALIZER
.
register_module
class
Initializer_3D
(
ProcessGroupInitializer
):
"""Serve as the single entry point to 3D parallel initialization.
...
...
@@ -200,6 +306,8 @@ class Initializer_3D(ProcessGroupInitializer):
self
.
input_initializer
=
Initializer_3D_Input
(
self
.
num_group
,
self
.
depth
,
*
args
)
self
.
weight_initializer
=
Initializer_3D_Weight
(
self
.
num_group
,
self
.
depth
,
*
args
)
self
.
output_initializer
=
Initializer_3D_Output
(
self
.
num_group
,
self
.
depth
,
*
args
)
self
.
input_x_weight_initializer
=
Initializer_3D_InputxWeight
(
self
.
num_group
,
self
.
depth
,
*
args
)
self
.
output_x_weight_initializer
=
Initializer_3D_OutputxWeight
(
self
.
num_group
,
self
.
depth
,
*
args
)
def
init_dist_group
(
self
):
"""Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu.
...
...
@@ -211,6 +319,8 @@ class Initializer_3D(ProcessGroupInitializer):
parallel_setting
=
[
self
.
input_initializer
.
init_dist_group
(),
self
.
weight_initializer
.
init_dist_group
(),
self
.
output_initializer
.
init_dist_group
()
self
.
output_initializer
.
init_dist_group
(),
self
.
input_x_weight_initializer
.
init_dist_group
(),
self
.
output_x_weight_initializer
.
init_dist_group
()
]
return
parallel_setting
colossalai/global_variables.py
View file @
0b8161fa
...
...
@@ -22,7 +22,9 @@ class TensorParallelEnv(object):
depth_3d
:
int
=
None
,
input_group_3d
=
None
,
weight_group_3d
=
None
,
output_group_3d
=
None
):
output_group_3d
=
None
,
input_x_weight_group_3d
=
None
,
output_x_weight_group_3d
=
None
):
self
.
mode
=
mode
self
.
vocab_parallel
=
vocab_parallel
self
.
parallel_input_1d
=
parallel_input_1d
...
...
@@ -33,6 +35,8 @@ class TensorParallelEnv(object):
self
.
input_group_3d
=
input_group_3d
self
.
weight_group_3d
=
weight_group_3d
self
.
output_group_3d
=
output_group_3d
self
.
input_x_weight_group_3d
=
input_x_weight_group_3d
self
.
output_x_weight_group_3d
=
output_x_weight_group_3d
def
save
(
self
):
return
dict
(
mode
=
self
.
mode
,
...
...
@@ -44,7 +48,9 @@ class TensorParallelEnv(object):
depth_3d
=
self
.
depth_3d
,
input_group_3d
=
self
.
input_group_3d
,
weight_group_3d
=
self
.
weight_group_3d
,
output_group_3d
=
self
.
output_group_3d
)
output_group_3d
=
self
.
output_group_3d
,
input_x_weight_group_3d
=
self
.
input_x_weight_group_3d
,
output_x_weight_group_3d
=
self
.
output_x_weight_group_3d
)
tensor_parallel_env
=
TensorParallelEnv
()
colossalai/nn/layer/parallel_1d/_operation.py
View file @
0b8161fa
import
torch
import
torch.distributed
as
dist
from
colossalai.core
import
global_context
as
gpc
try
:
import
fused_mix_prec_layer_norm_cuda
...
...
@@ -43,3 +45,52 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function):
weight_
,
bias_
,
ctx
.
eps
)
return
grad_input
,
grad_weight
,
grad_bias
,
None
,
None
class
LinearWithAsyncCommunication
(
torch
.
autograd
.
Function
):
"""
Linear layer execution with asynchronous communication in backprop.
"""
@
staticmethod
def
forward
(
ctx
,
input_
,
weight
,
bias
,
parallel_mode
,
async_grad_allreduce
):
ctx
.
save_for_backward
(
input_
,
weight
)
ctx
.
use_bias
=
bias
is
not
None
ctx
.
parallel_mode
=
parallel_mode
ctx
.
async_grad_allreduce
=
async_grad_allreduce
output
=
torch
.
matmul
(
input_
,
weight
.
t
())
if
bias
is
not
None
:
output
=
output
+
bias
return
output
@
staticmethod
def
backward
(
ctx
,
grad_output
):
input
,
weight
=
ctx
.
saved_tensors
use_bias
=
ctx
.
use_bias
total_input
=
input
grad_input
=
grad_output
.
matmul
(
weight
)
# Convert the tensor shapes to 2D for execution compatibility
grad_output
=
grad_output
.
view
(
grad_output
.
shape
[
0
]
*
grad_output
.
shape
[
1
],
grad_output
.
shape
[
2
])
total_input
=
total_input
.
view
(
total_input
.
shape
[
0
]
*
total_input
.
shape
[
1
],
total_input
.
shape
[
2
])
if
ctx
.
async_grad_allreduce
:
# Asynchronous all-reduce
handle
=
dist
.
all_reduce
(
grad_input
,
group
=
gpc
.
get_group
(
ctx
.
parallel_mode
),
async_op
=
True
)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_
=
torch
.
empty
(
1
,
device
=
grad_output
.
device
)
+
1
grad_weight
=
grad_output
.
t
().
matmul
(
total_input
)
grad_bias
=
grad_output
.
sum
(
dim
=
0
)
if
use_bias
else
None
if
ctx
.
async_grad_allreduce
:
handle
.
wait
()
return
grad_input
,
grad_weight
,
grad_bias
,
None
,
None
,
None
def
linear_with_async_comm
(
input_
,
weight
,
bias
,
parallel_mode
,
async_grad_allreduce
):
return
LinearWithAsyncCommunication
.
apply
(
input_
,
weight
,
bias
,
parallel_mode
,
async_grad_allreduce
)
colossalai/nn/layer/parallel_1d/layers.py
View file @
0b8161fa
...
...
@@ -20,12 +20,12 @@ from colossalai.utils.cuda import get_current_device
from
torch
import
Tensor
from
torch.nn.parameter
import
Parameter
from
..vanilla
import
VanillaPatchEmbedding
,
VanillaLayerNorm
from
..base_layer
import
ParallelLayer
from
..colossalai_layer._utils
import
ColossalaiModule
from
..utils
import
divide
,
set_tensor_parallel_attribute_by_partition
from
._utils
import
(
gather_forward_split_backward
,
get_parallel_input
,
reduce_grad
,
reduce_input
,
set_parallel_input
,
split_forward_gather_backward
)
from
._operation
import
linear_with_async_comm
@
LAYERS
.
register_module
...
...
@@ -96,8 +96,25 @@ class LayerNorm1D(ColossalaiModule):
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""
_fast_ln_supported_sizes
=
[
1024
,
1536
,
2048
,
2304
,
3072
,
3840
,
4096
,
5120
,
6144
,
8192
,
10240
,
12288
,
12800
,
15360
,
16384
,
18432
,
20480
,
24576
,
25600
,
30720
,
32768
,
40960
,
49152
,
65536
]
def
__init__
(
self
,
normalized_shape
:
int
,
eps
=
1e-05
,
bias
=
True
,
dtype
=
None
):
norm
=
VanillaLayerNorm
(
normalized_shape
,
eps
=
eps
,
bias
=
bias
,
dtype
=
dtype
)
from
apex.normalization
import
FusedLayerNorm
fast_ln_installed
=
False
try
:
from
apex.contrib.layer_norm.layer_norm
import
FastLayerNorm
fast_ln_installed
=
True
except
ImportError
:
pass
if
fast_ln_installed
and
normalized_shape
in
self
.
_fast_ln_supported_sizes
:
norm
=
FastLayerNorm
(
normalized_shape
,
eps
=
eps
).
to
(
dtype
)
else
:
norm
=
FusedLayerNorm
(
normalized_shape
,
eps
=
eps
).
to
(
dtype
)
super
().
__init__
(
norm
)
def
_load_from_state_dict
(
self
,
state_dict
,
prefix
,
*
args
):
...
...
@@ -519,11 +536,12 @@ class Linear1D_Col(ParallelLayer):
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'
.
format
(
input_
.
shape
,
self
.
weight
.
shape
,
self
.
weight
.
shape
[
-
1
])
# Set up backprop all-reduce.
input_parallel
=
reduce_grad
(
input_
,
ParallelMode
.
PARALLEL_1D
)
# input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
input_parallel
=
input_
# Matrix multiply.
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
output_parallel
=
F
.
linear
(
input_parallel
,
self
.
weight
,
bias
)
# output_parallel = F.linear(input_parallel, self.weight, bias)
output_parallel
=
linear_with_async_comm
(
input_parallel
,
self
.
weight
,
bias
,
ParallelMode
.
PARALLEL_1D
,
True
)
if
self
.
gather_output
:
# All-gather across the partitions.
output
=
gather_forward_split_backward
(
output_parallel
,
ParallelMode
.
PARALLEL_1D
,
dim
=-
1
)
...
...
@@ -665,6 +683,7 @@ class Linear1D_Row(ParallelLayer):
input_
=
split_forward_gather_backward
(
input_
,
ParallelMode
.
PARALLEL_1D
,
dim
=-
1
)
output_parallel
=
F
.
linear
(
input_
,
self
.
weight
)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output
=
reduce_input
(
output_parallel
,
ParallelMode
.
PARALLEL_1D
)
if
not
self
.
skip_bias_add
:
...
...
colossalai/nn/layer/parallel_3d/_operation.py
View file @
0b8161fa
This diff is collapsed.
Click to expand it.
colossalai/nn/layer/parallel_3d/_utils.py
View file @
0b8161fa
from
colossalai.constants
import
INPUT_GROUP_3D
,
WEIGHT_GROUP_3D
,
OUTPUT_GROUP_3D
from
collections
import
OrderedDict
from
functools
import
partial
import
torch
from
torch
import
Tensor
from
colossalai.constants
import
INPUT_GROUP_3D
,
INPUT_X_WEIGHT_3D
,
OUTPUT_GROUP_3D
,
OUTPUT_X_WEIGHT_3D
,
WEIGHT_GROUP_3D
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.global_variables
import
tensor_parallel_env
as
env
from
torch
import
Tensor
def
get_depth_from_env
()
->
int
:
...
...
@@ -17,30 +22,17 @@ def get_depth_from_env() -> int:
def
get_parallel_mode_from_env
(
group
):
assert
group
in
[
INPUT_GROUP_3D
,
WEIGHT_GROUP_3D
,
OUTPUT_GROUP_3D
],
\
assert
group
in
[
INPUT_GROUP_3D
,
WEIGHT_GROUP_3D
,
OUTPUT_GROUP_3D
,
INPUT_X_WEIGHT_3D
,
OUTPUT_X_WEIGHT_3D
],
\
f
'
{
group
}
is not valid for 3D tensor parallelism.'
return
getattr
(
env
,
group
)
def
get_last_group
(
a
,
b
):
mapping
=
{
ParallelMode
.
PARALLEL_3D_INPUT
:
'A'
,
ParallelMode
.
PARALLEL_3D_WEIGHT
:
'B'
,
ParallelMode
.
PARALLEL_3D_OUTPUT
:
'C'
,
}
res
=
chr
(
ord
(
'A'
)
+
ord
(
'B'
)
+
ord
(
'C'
)
-
ord
(
mapping
[
a
])
-
ord
(
mapping
[
b
]))
if
res
==
'A'
:
return
ParallelMode
.
PARALLEL_3D_INPUT
elif
res
==
'B'
:
return
ParallelMode
.
PARALLEL_3D_WEIGHT
elif
res
==
'C'
:
return
ParallelMode
.
PARALLEL_3D_OUTPUT
def
swap_in_out_group
():
env
.
input_group_3d
,
env
.
output_group_3d
=
env
.
output_group_3d
,
env
.
input_group_3d
env
.
input_x_weight_group_3d
,
env
.
output_x_weight_group_3d
=
(
env
.
output_x_weight_group_3d
,
env
.
input_x_weight_group_3d
,
)
def
dbg_check_shape
(
tensor
:
Tensor
,
shape
:
tuple
):
...
...
@@ -49,3 +41,60 @@ def dbg_check_shape(tensor: Tensor, shape: tuple):
print
(
tensor
.
shape
)
assert
tensor
.
shape
==
shape
,
\
'{} does not match {}'
.
format
(
tensor
.
shape
,
shape
)
class
AsyncGradientBucket
(
object
):
def
__init__
(
self
):
self
.
bucket
=
OrderedDict
()
def
__len__
(
self
):
return
len
(
self
.
bucket
)
def
push
(
self
,
async_op
,
grad_tensor
,
param_id
):
self
.
bucket
[
param_id
]
=
tuple
((
async_op
,
grad_tensor
))
return
torch
.
zeros_like
(
grad_tensor
,
dtype
=
grad_tensor
.
dtype
,
device
=
grad_tensor
.
device
)
def
pop
(
self
,
param_id
):
grad
=
None
if
param_id
in
self
.
bucket
:
op
,
grad
=
self
.
bucket
.
pop
(
param_id
)
if
op
is
not
None
:
op
.
wait
()
return
grad
def
synchronize
(
self
,
params
):
for
p
in
params
:
i
=
id
(
p
)
if
i
in
self
.
bucket
:
op
,
grad
=
self
.
bucket
.
pop
(
i
)
if
op
is
not
None
:
op
.
wait
()
p
.
grad
.
add_
(
grad
)
_async_grad_bucket
=
AsyncGradientBucket
()
def
push_async_grad
(
op
,
grad
,
param_id
):
return
_async_grad_bucket
.
push
(
op
,
grad
,
param_id
)
def
pop_async_grad
(
param_id
):
return
_async_grad_bucket
.
pop
(
param_id
)
def
_async_grad_hook
(
grad
,
param_id
):
grad
.
add_
(
pop_async_grad
(
param_id
))
return
grad
def
register_async_grad_hook
(
param
):
param
.
register_hook
(
partial
(
_async_grad_hook
,
param_id
=
id
(
param
)))
def
synchronize
(
params
=
list
()):
_async_grad_bucket
.
synchronize
(
params
)
torch
.
cuda
.
default_stream
().
synchronize
()
if
len
(
_async_grad_bucket
)
>
0
:
raise
RuntimeError
(
f
"
{
len
(
_async_grad_bucket
)
}
asynchronous gradient(s) not collected."
)
colossalai/nn/layer/parallel_3d/layers.py
View file @
0b8161fa
...
...
@@ -6,7 +6,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
colossalai.communication
import
all_reduce
,
broadcast
from
colossalai.constants
import
INPUT_GROUP_3D
,
WEIGHT_GROUP_3D
from
colossalai.constants
import
INPUT_GROUP_3D
,
INPUT_X_WEIGHT_3D
,
OUTPUT_GROUP_3D
,
OUTPUT_X_WEIGHT_3D
,
WEIGHT_GROUP_3D
from
colossalai.context
import
ParallelMode
,
seed
from
colossalai.core
import
global_context
as
gpc
from
colossalai.global_variables
import
tensor_parallel_env
as
env
...
...
@@ -20,9 +20,9 @@ from torch import Tensor
from
torch.nn
import
Parameter
from
..utils
import
divide
,
set_tensor_parallel_attribute_by_partition
,
to_2tuple
from
._operation
import
(
all_gather_tensor_3d
,
broadcast_weight_3d_from_diagonal
,
classifier_3d
,
layernorm_3d
,
linear_3d
,
reduce_scatter_tensor_3d
,
split_tensor_3d
)
from
._utils
import
get_depth_from_env
,
get_last_group
,
get_parallel_mode_from_env
,
swap_in_out_group
from
._operation
import
(
all_gather_tensor_3d
,
classifier_3d
,
vocab_parallel_
classifier_3d
,
layernorm_3d
,
linear_3d
,
reduce_scatter_tensor_3d
,
split_tensor_3d
,
split_batch_3d
)
from
._utils
import
get_depth_from_env
,
get_parallel_mode_from_env
,
swap_in_out_group
,
register_async_grad_hook
@
LAYERS
.
register_module
...
...
@@ -45,7 +45,8 @@ class LayerNorm3D(ParallelLayer):
super
().
__init__
()
self
.
input_parallel_mode
=
get_parallel_mode_from_env
(
INPUT_GROUP_3D
)
self
.
weight_parallel_mode
=
get_parallel_mode_from_env
(
WEIGHT_GROUP_3D
)
self
.
output_parallel_mode
=
get_last_group
(
self
.
input_parallel_mode
,
self
.
weight_parallel_mode
)
self
.
output_parallel_mode
=
get_parallel_mode_from_env
(
OUTPUT_GROUP_3D
)
self
.
input_x_weight_parallel_mode
=
get_parallel_mode_from_env
(
INPUT_X_WEIGHT_3D
)
self
.
depth
=
get_depth_from_env
()
self
.
normalized_shape
=
normalized_shape
self
.
normalized_shape_per_partition
=
divide
(
normalized_shape
,
self
.
depth
)
...
...
@@ -58,6 +59,7 @@ class LayerNorm3D(ParallelLayer):
else
:
self
.
bias
=
None
self
.
variance_epsilon
=
eps
self
.
reset_parameters
()
self
.
_set_tensor_parallel_attributes
()
def
_set_tensor_parallel_attributes
(
self
)
->
None
:
...
...
@@ -67,8 +69,10 @@ class LayerNorm3D(ParallelLayer):
def
reset_parameters
(
self
)
->
None
:
init
.
ones_
()(
self
.
weight
)
register_async_grad_hook
(
self
.
weight
)
if
self
.
bias
is
not
None
:
init
.
zeros_
()(
self
.
bias
)
register_async_grad_hook
(
self
.
bias
)
def
_load_from_global_state_dict
(
self
,
state_dict
,
prefix
,
*
args
,
**
kwargs
):
local_state
=
OrderedDict
()
...
...
@@ -134,8 +138,17 @@ class LayerNorm3D(ParallelLayer):
destination
.
update
(
local_state
)
def
forward
(
self
,
input_
:
Tensor
)
->
Tensor
:
return
layernorm_3d
(
input_
,
self
.
weight
,
self
.
bias
,
self
.
normalized_shape
,
self
.
variance_epsilon
,
self
.
input_parallel_mode
,
self
.
weight_parallel_mode
,
self
.
output_parallel_mode
)
return
layernorm_3d
(
input_
,
self
.
weight
,
self
.
bias
,
self
.
normalized_shape
,
self
.
variance_epsilon
,
self
.
input_parallel_mode
,
self
.
weight_parallel_mode
,
self
.
output_parallel_mode
,
self
.
input_x_weight_parallel_mode
,
)
@
LAYERS
.
register_module
...
...
@@ -161,6 +174,7 @@ class Linear3D(ParallelLayer):
out_features
:
int
,
bias
:
bool
=
True
,
dtype
:
torch
.
dtype
=
None
,
skip_bias_add
:
bool
=
False
,
weight_initializer
:
Callable
=
init
.
kaiming_uniform_
(
a
=
math
.
sqrt
(
5
)),
bias_initializer
:
Callable
=
init
.
xavier_uniform_
(
a
=
1
,
scale
=
1
)):
super
().
__init__
()
...
...
@@ -168,8 +182,10 @@ class Linear3D(ParallelLayer):
self
.
out_features
=
out_features
self
.
input_parallel_mode
=
get_parallel_mode_from_env
(
INPUT_GROUP_3D
)
self
.
weight_parallel_mode
=
get_parallel_mode_from_env
(
WEIGHT_GROUP_3D
)
self
.
output_parallel_mode
=
get_last_group
(
self
.
input_parallel_mode
,
self
.
weight_parallel_mode
)
self
.
output_parallel_mode
=
get_parallel_mode_from_env
(
OUTPUT_GROUP_3D
)
self
.
output_x_weight_parallel_mode
=
get_parallel_mode_from_env
(
OUTPUT_X_WEIGHT_3D
)
self
.
depth
=
get_depth_from_env
()
self
.
skip_bias_add
=
skip_bias_add
self
.
in_features_per_partition
=
divide
(
in_features
,
self
.
depth
)
self
.
out_features_per_partition
=
divide
(
out_features
,
self
.
depth
**
2
)
self
.
bias_features_per_partition
=
divide
(
out_features
,
self
.
depth
)
...
...
@@ -194,18 +210,23 @@ class Linear3D(ParallelLayer):
if
self
.
bias
is
not
None
:
set_tensor_parallel_attribute_by_partition
(
self
.
bias
,
self
.
depth
)
def
_sync_grad_hook
(
self
,
grad
)
->
Tensor
:
grad
=
all_reduce
(
grad
.
clone
(),
self
.
output_x_weight_parallel_mode
)
return
grad
def
reset_parameters
(
self
,
weight_initializer
,
bias_initializer
)
->
None
:
with
seed
(
ParallelMode
.
TENSOR
):
fan_in
,
fan_out
=
self
.
in_features
,
self
.
out_features
weight_initializer
(
self
.
weight
,
fan_in
=
fan_in
,
fan_out
=
fan_out
)
register_async_grad_hook
(
self
.
weight
)
if
self
.
bias
is
not
None
:
bias_initializer
(
self
.
bias
,
fan_in
=
fan_in
)
weight_src_rank
=
gpc
.
get_ranks_in_group
(
self
.
weight_parallel_mode
)[
0
]
output_src_rank
=
gpc
.
get_ranks_in_group
(
self
.
output_parallel_mode
)[
0
]
broadcast
(
self
.
bias
,
weight_src_rank
,
self
.
weight_parallel_mode
)
broadcast
(
self
.
bias
,
output_src_rank
,
self
.
output_parallel_mode
)
broadcast
(
self
.
bias
,
gpc
.
get_ranks_in_group
(
self
.
output_
x_weight_
parallel_mode
)[
0
]
,
self
.
output_x_
weight_parallel_mode
)
self
.
bias
.
register_hook
(
self
.
_sync_grad_hook
)
def
_load_from_global_state_dict
(
self
,
state_dict
,
prefix
,
*
args
,
**
kwargs
):
local_state
=
OrderedDict
()
...
...
@@ -324,8 +345,20 @@ class Linear3D(ParallelLayer):
destination
.
update
(
local_state
)
def
forward
(
self
,
input_
:
Tensor
)
->
Tensor
:
return
linear_3d
(
input_
,
self
.
weight
,
self
.
bias
,
self
.
input_parallel_mode
,
self
.
weight_parallel_mode
,
self
.
output_parallel_mode
)
output
=
linear_3d
(
input_
,
self
.
weight
,
self
.
input_parallel_mode
,
self
.
weight_parallel_mode
,
self
.
output_parallel_mode
,
)
if
not
self
.
skip_bias_add
:
if
self
.
bias
is
not
None
:
output
=
output
+
self
.
bias
return
output
else
:
return
output
,
self
.
bias
@
LAYERS
.
register_module
...
...
@@ -360,7 +393,7 @@ class Classifier3D(ParallelLayer):
self
.
num_classes
=
num_classes
self
.
input_parallel_mode
=
get_parallel_mode_from_env
(
INPUT_GROUP_3D
)
self
.
weight_parallel_mode
=
get_parallel_mode_from_env
(
WEIGHT_GROUP_3D
)
self
.
output_parallel_mode
=
get_
last_group
(
self
.
input_parallel_mode
,
self
.
weight_parallel_mode
)
self
.
output_parallel_mode
=
get_
parallel_mode_from_env
(
OUTPUT_GROUP_3D
)
self
.
depth
=
get_depth_from_env
()
self
.
in_features_per_partition
=
divide
(
in_features
,
self
.
depth
)
...
...
@@ -386,19 +419,17 @@ class Classifier3D(ParallelLayer):
def
reset_parameters
(
self
,
weight_initializer
,
bias_initializer
)
->
None
:
with
seed
(
ParallelMode
.
TENSOR
):
fan_in
,
fan_out
=
self
.
in_features
,
self
.
num_classes
weight_src_rank
=
gpc
.
get_ranks_in_group
(
self
.
weight_parallel_mode
)[
0
]
output_src_rank
=
gpc
.
get_ranks_in_group
(
self
.
output_parallel_mode
)[
0
]
input_src_rank
=
gpc
.
get_ranks_in_group
(
self
.
input_parallel_mode
)[
0
]
if
self
.
has_weight
:
weight_initializer
(
self
.
weight
,
fan_in
=
fan_in
,
fan_out
=
fan_out
)
broadcast
(
self
.
weight
,
weight_src_rank
,
self
.
weight_parallel_mode
)
broadcast
(
self
.
weight
,
gpc
.
get_ranks_in_group
(
self
.
weight_parallel_mode
)[
0
],
self
.
weight_parallel_mode
)
register_async_grad_hook
(
self
.
weight
)
if
self
.
bias
is
not
None
:
bias_initializer
(
self
.
bias
,
fan_in
=
fan_in
)
broadcast
(
self
.
bias
,
weight_src_rank
,
self
.
weight_parallel_mode
)
broadcast
(
self
.
bias
,
output_src_rank
,
self
.
output_parallel_mode
)
broadcast
(
self
.
bias
,
input_src_rank
,
self
.
input_parallel_mode
)
broadcast
(
self
.
bias
,
gpc
.
get_ranks_in_group
(
ParallelMode
.
TENSOR
)[
0
],
ParallelMode
.
TENSOR
)
register_async_grad_hook
(
self
.
bias
)
def
_load_from_global_state_dict
(
self
,
state_dict
,
prefix
,
*
args
,
**
kwargs
):
local_state
=
OrderedDict
()
...
...
@@ -468,8 +499,14 @@ class Classifier3D(ParallelLayer):
destination
.
update
(
local_state
)
def
forward
(
self
,
input_
:
Tensor
)
->
Tensor
:
return
classifier_3d
(
input_
,
self
.
weight
,
self
.
bias
,
self
.
input_parallel_mode
,
self
.
weight_parallel_mode
,
self
.
output_parallel_mode
)
return
classifier_3d
(
input_
,
self
.
weight
,
self
.
bias
,
self
.
input_parallel_mode
,
self
.
weight_parallel_mode
,
self
.
output_parallel_mode
,
)
@
LAYERS
.
register_module
...
...
@@ -504,7 +541,8 @@ class VocabParallelClassifier3D(ParallelLayer):
self
.
num_classes
=
num_classes
self
.
input_parallel_mode
=
get_parallel_mode_from_env
(
INPUT_GROUP_3D
)
self
.
weight_parallel_mode
=
get_parallel_mode_from_env
(
WEIGHT_GROUP_3D
)
self
.
output_parallel_mode
=
get_last_group
(
self
.
input_parallel_mode
,
self
.
weight_parallel_mode
)
self
.
output_parallel_mode
=
get_parallel_mode_from_env
(
OUTPUT_GROUP_3D
)
self
.
output_x_weight_parallel_mode
=
get_parallel_mode_from_env
(
OUTPUT_X_WEIGHT_3D
)
self
.
depth
=
get_depth_from_env
()
self
.
in_features_per_partition
=
divide
(
in_features
,
self
.
depth
)
self
.
out_features_per_partition
=
divide
(
num_classes
,
self
.
depth
**
2
)
...
...
@@ -544,12 +582,14 @@ class VocabParallelClassifier3D(ParallelLayer):
if
self
.
has_weight
:
weight_initializer
(
self
.
weight
,
fan_in
=
fan_in
,
fan_out
=
fan_out
)
register_async_grad_hook
(
self
.
weight
)
if
self
.
bias
is
not
None
:
bias_initializer
(
self
.
bias
,
fan_in
=
fan_in
)
weight_src_rank
=
gpc
.
get_ranks_in_group
(
self
.
weight_parallel_mode
)[
0
]
output_src_rank
=
gpc
.
get_ranks_in_group
(
self
.
output_parallel_mode
)[
0
]
broadcast
(
self
.
bias
,
weight_src_rank
,
self
.
weight_parallel_mode
)
broadcast
(
self
.
bias
,
output_sr
c_ra
nk
,
self
.
output_parallel_mode
)
broadcast
(
self
.
bias
,
gpc
.
get_ranks_in_group
(
self
.
output_
x_weight_
parallel_mode
)[
0
]
,
self
.
output_x_
weight_parallel_mode
)
register_asyn
c_
g
ra
d_hook
(
self
.
bias
)
def
_load_from_global_state_dict
(
self
,
state_dict
,
prefix
,
*
args
,
**
kwargs
):
local_state
=
OrderedDict
()
...
...
@@ -668,8 +708,14 @@ class VocabParallelClassifier3D(ParallelLayer):
destination
.
update
(
local_state
)
def
forward
(
self
,
input_
:
Tensor
)
->
Tensor
:
return
linear_3d
(
input_
,
self
.
weight
.
transpose
(
0
,
1
),
self
.
bias
,
self
.
input_parallel_mode
,
self
.
weight_parallel_mode
,
self
.
output_parallel_mode
)
return
vocab_parallel_classifier_3d
(
input_
,
self
.
weight
,
self
.
bias
,
self
.
input_parallel_mode
,
self
.
weight_parallel_mode
,
self
.
output_parallel_mode
,
)
@
LAYERS
.
register_module
...
...
@@ -708,12 +754,16 @@ class PatchEmbedding3D(ParallelLayer):
self
.
depth
=
get_depth_from_env
()
self
.
input_parallel_mode
=
get_parallel_mode_from_env
(
INPUT_GROUP_3D
)
self
.
weight_parallel_mode
=
get_parallel_mode_from_env
(
WEIGHT_GROUP_3D
)
self
.
output_parallel_mode
=
get_last_group
(
self
.
input_parallel_mode
,
self
.
weight_parallel_mode
)
self
.
patch_size
=
to_2tuple
(
patch_size
)
grid_size
=
to_2tuple
(
img_size
//
patch_size
)
num_patches
=
grid_size
[
0
]
*
grid_size
[
1
]
self
.
output_parallel_mode
=
get_parallel_mode_from_env
(
OUTPUT_GROUP_3D
)
self
.
input_x_weight_parallel_mode
=
get_parallel_mode_from_env
(
INPUT_X_WEIGHT_3D
)
img_size
=
to_2tuple
(
img_size
)
patch_size
=
to_2tuple
(
patch_size
)
self
.
img_size
=
img_size
self
.
patch_size
=
patch_size
self
.
grid_size
=
(
img_size
[
0
]
//
patch_size
[
0
],
img_size
[
1
]
//
patch_size
[
1
])
self
.
num_patches
=
self
.
grid_size
[
0
]
*
self
.
grid_size
[
1
]
self
.
embed_size
=
embed_size
embed_size_per_partition
=
divide
(
embed_size
,
self
.
depth
)
embed_size_per_partition
=
embed_size
//
self
.
depth
self
.
flatten
=
flatten
self
.
weight
=
nn
.
Parameter
(
...
...
@@ -725,7 +775,7 @@ class PatchEmbedding3D(ParallelLayer):
self
.
cls_token
=
nn
.
Parameter
(
torch
.
zeros
((
1
,
1
,
embed_size_per_partition
),
device
=
get_current_device
(),
dtype
=
dtype
))
self
.
pos_embed
=
nn
.
Parameter
(
torch
.
zeros
((
1
,
num_patches
+
1
,
embed_size_per_partition
),
device
=
get_current_device
(),
dtype
=
dtype
))
torch
.
zeros
((
1
,
self
.
num_patches
+
1
,
embed_size_per_partition
),
device
=
get_current_device
(),
dtype
=
dtype
))
self
.
reset_parameters
(
weight_initializer
,
bias_initializer
,
position_embed_initializer
)
self
.
_set_tensor_parallel_attributes
()
...
...
@@ -737,8 +787,7 @@ class PatchEmbedding3D(ParallelLayer):
set_tensor_parallel_attribute_by_partition
(
self
.
pos_embed
,
self
.
depth
)
def
_sync_grad_hook
(
self
,
grad
)
->
Tensor
:
grad
=
all_reduce
(
grad
.
clone
(),
self
.
input_parallel_mode
)
grad
=
all_reduce
(
grad
,
self
.
weight_parallel_mode
)
grad
=
all_reduce
(
grad
.
clone
(),
self
.
input_x_weight_parallel_mode
)
return
grad
def
reset_parameters
(
self
,
weight_initializer
,
bias_initializer
,
position_embed_initializer
)
->
None
:
...
...
@@ -749,14 +798,10 @@ class PatchEmbedding3D(ParallelLayer):
bias_initializer
(
self
.
bias
,
fan_in
=
fan_in
)
position_embed_initializer
(
self
.
pos_embed
)
weight_src_rank
=
gpc
.
get_ranks_in_group
(
self
.
weight_parallel_mode
)[
0
]
input_src_rank
=
gpc
.
get_ranks_in_group
(
self
.
input_parallel_mode
)[
0
]
broadcast
(
self
.
weight
,
weight_src_rank
,
self
.
weight_parallel_mode
)
broadcast
(
self
.
bias
,
weight_src_rank
,
self
.
weight_parallel_mode
)
broadcast
(
self
.
pos_embed
,
weight_src_rank
,
self
.
weight_parallel_mode
)
broadcast
(
self
.
weight
,
input_src_rank
,
self
.
input_parallel_mode
)
broadcast
(
self
.
bias
,
input_src_rank
,
self
.
input_parallel_mode
)
broadcast
(
self
.
pos_embed
,
input_src_rank
,
self
.
input_parallel_mode
)
src_rank
=
gpc
.
get_ranks_in_group
(
self
.
input_x_weight_parallel_mode
)[
0
]
broadcast
(
self
.
weight
,
src_rank
,
self
.
input_x_weight_parallel_mode
)
broadcast
(
self
.
bias
,
src_rank
,
self
.
input_x_weight_parallel_mode
)
broadcast
(
self
.
pos_embed
,
src_rank
,
self
.
input_x_weight_parallel_mode
)
self
.
weight
.
register_hook
(
self
.
_sync_grad_hook
)
self
.
bias
.
register_hook
(
self
.
_sync_grad_hook
)
...
...
@@ -850,11 +895,12 @@ class PatchEmbedding3D(ParallelLayer):
destination
.
update
(
local_state
)
def
forward
(
self
,
input_
:
Tensor
)
->
Tensor
:
input_
=
split_tensor_3d
(
input_
,
0
,
self
.
weight_parallel_mode
)
input_
=
split_tensor_3d
(
input_
,
0
,
self
.
input_parallel_mode
)
input_
=
split_batch_3d
(
input_
,
input_parallel_mode
=
self
.
input_parallel_mode
,
weight_parallel_mode
=
self
.
weight_parallel_mode
)
output
=
F
.
conv2d
(
input_
,
self
.
weight
,
self
.
bias
,
stride
=
self
.
patch_size
)
if
self
.
flatten
:
output
=
output
.
flatten
(
2
).
transpose
(
1
,
2
)
# BCHW -> BNC
output
=
output
.
flatten
(
2
).
transpose
(
1
,
2
)
# BCHW -> BNC
cls_token
=
self
.
cls_token
.
expand
(
output
.
shape
[
0
],
-
1
,
-
1
)
output
=
torch
.
cat
((
cls_token
,
output
),
dim
=
1
)
...
...
@@ -906,7 +952,8 @@ class Embedding3D(ParallelLayer):
self
.
depth
=
get_depth_from_env
()
self
.
input_parallel_mode
=
get_parallel_mode_from_env
(
INPUT_GROUP_3D
)
self
.
weight_parallel_mode
=
get_parallel_mode_from_env
(
WEIGHT_GROUP_3D
)
self
.
output_parallel_mode
=
get_last_group
(
self
.
input_parallel_mode
,
self
.
weight_parallel_mode
)
self
.
output_parallel_mode
=
get_parallel_mode_from_env
(
OUTPUT_GROUP_3D
)
self
.
input_x_weight_parallel_mode
=
get_parallel_mode_from_env
(
INPUT_X_WEIGHT_3D
)
self
.
num_embeddings
=
num_embeddings
self
.
embed_dim
=
embedding_dim
...
...
@@ -924,13 +971,18 @@ class Embedding3D(ParallelLayer):
def
_set_tensor_parallel_attributes
(
self
)
->
None
:
set_tensor_parallel_attribute_by_partition
(
self
.
weight
,
self
.
depth
)
def
_sync_grad_hook
(
self
,
grad
)
->
Tensor
:
grad
=
all_reduce
(
grad
.
clone
(),
self
.
input_x_weight_parallel_mode
)
return
grad
def
reset_parameters
(
self
,
weight_initializer
)
->
None
:
with
seed
(
ParallelMode
.
TENSOR
):
fan_in
,
fan_out
=
self
.
num_embeddings
,
self
.
embed_dim
weight_initializer
(
self
.
weight
,
fan_in
=
fan_in
,
fan_out
=
fan_out
)
self
.
_fill_padding_idx_with_zero
()
weight_src_rank
=
gpc
.
get_ranks_in_group
(
self
.
weight_parallel_mode
)[
0
]
broadcast
(
self
.
weight
,
weight_src_rank
,
self
.
weight_parallel_mode
)
broadcast
(
self
.
weight
,
gpc
.
get_ranks_in_group
(
self
.
input_x_weight_parallel_mode
)[
0
],
self
.
input_x_weight_parallel_mode
)
self
.
weight
.
register_hook
(
self
.
_sync_grad_hook
)
def
_fill_padding_idx_with_zero
(
self
)
->
None
:
if
self
.
padding_idx
is
not
None
:
...
...
@@ -981,11 +1033,10 @@ class Embedding3D(ParallelLayer):
destination
.
update
(
local_state
)
def
forward
(
self
,
input_
:
Tensor
)
->
Tensor
:
input_
=
split_tensor_3d
(
input_
,
0
,
self
.
weight_parallel_mode
)
input_
=
split_tensor_3d
(
input_
,
0
,
self
.
input_parallel_mode
)
weight
=
broadcast_weight_3d_from_diagonal
(
self
.
weight
,
self
.
input_parallel_mode
,
self
.
weight_parallel_mode
,
self
.
output_parallel_mode
)
output
=
F
.
embedding
(
input_
,
weight
,
self
.
padding_idx
,
*
self
.
embed_args
,
**
self
.
embed_kwargs
)
input_
=
split_batch_3d
(
input_
,
input_parallel_mode
=
self
.
input_parallel_mode
,
weight_parallel_mode
=
self
.
weight_parallel_mode
)
output
=
F
.
embedding
(
input_
,
self
.
weight
,
self
.
padding_idx
,
*
self
.
embed_args
,
**
self
.
embed_kwargs
)
return
output
...
...
@@ -1039,7 +1090,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
self
.
depth
=
get_depth_from_env
()
self
.
input_parallel_mode
=
get_parallel_mode_from_env
(
INPUT_GROUP_3D
)
self
.
weight_parallel_mode
=
get_parallel_mode_from_env
(
WEIGHT_GROUP_3D
)
self
.
output_parallel_mode
=
get_
last_group
(
self
.
input_parallel_mode
,
self
.
weight_parallel_mode
)
self
.
output_parallel_mode
=
get_
parallel_mode_from_env
(
OUTPUT_GROUP_3D
)
self
.
num_embeddings_per_partition
=
divide
(
self
.
num_embeddings
,
self
.
depth
**
2
)
self
.
embed_dim_per_partition
=
divide
(
self
.
embed_dim
,
self
.
depth
)
vocab_parallel_rank
=
gpc
.
get_local_rank
(
self
.
input_parallel_mode
)
...
...
docker/Dockerfile
View file @
0b8161fa
...
...
@@ -6,12 +6,12 @@ RUN conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
# install apex
RUN
git clone https://github.com/NVIDIA/apex
&&
\
cd
apex
&&
\
pip
install
-v
--disable-pip-version-check
--no-cache-dir
--global-option
=
"--cpp_ext"
--global-option
=
"--cuda_ext"
./
pip
install
-v
--disable-pip-version-check
--no-cache-dir
--global-option
=
"--cpp_ext"
--global-option
=
"--cuda_ext"
--global-option
=
"--fast_layer_norm"
./
# install colossalai
RUN
git clone https://github.com/hpcaitech/ColossalAI.git
\
&&
cd
./ColossalAI
\
&&
pip
install
-v
--no-cache-dir
.
&&
cd
./ColossalAI
\
&&
pip
install
-v
--no-cache-dir
.
# install titans
RUN
pip
install
--no-cache-dir
titans
...
...
tests/test_layers/test_3d/checks_3d/check_layer_3d.py
View file @
0b8161fa
...
...
@@ -20,7 +20,6 @@ def check_linear():
rank
=
torch
.
distributed
.
get_rank
()
logger
=
get_dist_logger
()
device
=
get_current_device
()
dtype
=
torch
.
float32
INPUT_SIZE
=
HIDDEN_SIZE
OUTPUT_SIZE
=
2
*
HIDDEN_SIZE
...
...
@@ -32,12 +31,12 @@ def check_linear():
i
=
global_context
.
get_local_rank
(
weight_parallel_mode
)
k
=
global_context
.
get_local_rank
(
output_parallel_mode
)
layer
=
Linear3D
(
INPUT_SIZE
,
OUTPUT_SIZE
,
dtype
=
dtype
,
bias
=
True
)
layer
=
Linear3D
(
INPUT_SIZE
,
OUTPUT_SIZE
,
bias
=
True
)
layer
=
layer
.
to
(
device
)
layer_master
=
torch
.
nn
.
Linear
(
INPUT_SIZE
,
OUTPUT_SIZE
)
layer_master
=
layer_master
.
to
(
device
)
weight_master
=
layer_master
.
weight
.
data
.
transpose
(
0
,
1
)
weight_master
=
layer_master
.
weight
.
data
.
transpose
(
0
,
1
)
.
contiguous
()
torch
.
distributed
.
broadcast
(
weight_master
,
src
=
0
)
weight
=
torch
.
chunk
(
weight_master
,
DEPTH
,
dim
=
0
)[
k
]
weight
=
torch
.
chunk
(
weight
,
DEPTH
,
dim
=-
1
)[
j
]
...
...
@@ -49,7 +48,7 @@ def check_linear():
layer
.
bias
.
data
.
copy_
(
bias
)
A_shape
=
(
BATCH_SIZE
,
SEQ_LENGTH
,
INPUT_SIZE
)
A_master
=
torch
.
randn
(
A_shape
,
dtype
=
dtype
,
device
=
device
)
A_master
=
torch
.
randn
(
A_shape
,
device
=
device
)
torch
.
distributed
.
broadcast
(
A_master
,
src
=
0
)
A
=
torch
.
chunk
(
A_master
,
DEPTH
,
dim
=
0
)[
i
]
A
=
torch
.
chunk
(
A
,
DEPTH
,
dim
=-
1
)[
k
]
...
...
@@ -72,7 +71,7 @@ def check_linear():
logger
.
info
(
'Rank {} linear forward: {}'
.
format
(
rank
,
check_equal
(
out
,
C
)))
grad_shape
=
C_master
.
shape
grad_master
=
torch
.
randn
(
grad_shape
,
dtype
=
dtype
,
device
=
get_current_device
())
grad_master
=
torch
.
randn
(
grad_shape
,
device
=
get_current_device
())
torch
.
distributed
.
broadcast
(
grad_master
,
src
=
0
)
grad
=
torch
.
chunk
(
grad_master
,
DEPTH
,
dim
=
0
)[
i
]
grad
=
torch
.
chunk
(
grad
,
DEPTH
,
dim
=-
1
)[
j
]
...
...
@@ -108,7 +107,6 @@ def check_layernorm():
rank
=
torch
.
distributed
.
get_rank
()
logger
=
get_dist_logger
()
device
=
get_current_device
()
dtype
=
torch
.
float32
INPUT_SIZE
=
HIDDEN_SIZE
input_parallel_mode
=
get_parallel_mode_from_env
(
INPUT_GROUP_3D
)
...
...
@@ -119,7 +117,7 @@ def check_layernorm():
i
=
global_context
.
get_local_rank
(
weight_parallel_mode
)
k
=
global_context
.
get_local_rank
(
output_parallel_mode
)
norm
=
LayerNorm3D
(
INPUT_SIZE
,
eps
=
1e-6
,
dtype
=
dtype
)
norm
=
LayerNorm3D
(
INPUT_SIZE
,
eps
=
1e-6
)
norm
=
norm
.
to
(
device
)
norm_master
=
torch
.
nn
.
LayerNorm
(
INPUT_SIZE
,
eps
=
1e-6
)
norm_master
=
norm_master
.
to
(
device
)
...
...
@@ -134,7 +132,7 @@ def check_layernorm():
norm
.
bias
.
data
.
copy_
(
bias
)
A_shape
=
(
BATCH_SIZE
,
SEQ_LENGTH
,
INPUT_SIZE
)
A_master
=
torch
.
randn
(
A_shape
,
dtype
=
dtype
,
device
=
device
)
A_master
=
torch
.
randn
(
A_shape
,
device
=
device
)
torch
.
distributed
.
broadcast
(
A_master
,
src
=
0
)
A
=
torch
.
chunk
(
A_master
,
DEPTH
,
dim
=
0
)[
i
]
A
=
torch
.
chunk
(
A
,
DEPTH
,
dim
=-
1
)[
k
]
...
...
@@ -159,7 +157,7 @@ def check_layernorm():
logger
.
info
(
'Rank {} layernorm forward: {}'
.
format
(
rank
,
check_equal
(
out
,
C
)))
grad_shape
=
C_master
.
shape
grad_master
=
torch
.
randn
(
grad_shape
,
dtype
=
dtype
,
device
=
device
)
grad_master
=
torch
.
randn
(
grad_shape
,
device
=
device
)
torch
.
distributed
.
broadcast
(
grad_master
,
src
=
0
)
grad
=
torch
.
chunk
(
grad_master
,
DEPTH
,
dim
=
0
)[
i
]
grad
=
torch
.
chunk
(
grad
,
DEPTH
,
dim
=-
1
)[
k
]
...
...
@@ -193,7 +191,6 @@ def check_classifier_no_given_weight():
rank
=
torch
.
distributed
.
get_rank
()
logger
=
get_dist_logger
()
device
=
get_current_device
()
dtype
=
torch
.
float32
INPUT_SIZE
=
HIDDEN_SIZE
input_parallel_mode
=
get_parallel_mode_from_env
(
INPUT_GROUP_3D
)
...
...
@@ -204,10 +201,10 @@ def check_classifier_no_given_weight():
i
=
global_context
.
get_local_rank
(
weight_parallel_mode
)
k
=
global_context
.
get_local_rank
(
output_parallel_mode
)
layer
=
Classifier3D
(
INPUT_SIZE
,
NUM_CLASSES
,
dtype
=
dtype
,
bias
=
True
)
layer
=
Classifier3D
(
INPUT_SIZE
,
NUM_CLASSES
,
bias
=
True
)
layer
=
layer
.
to
(
device
)
layer_master
=
VanillaClassifier
(
INPUT_SIZE
,
NUM_CLASSES
,
bias
=
True
,
dtype
=
dtype
)
layer_master
=
VanillaClassifier
(
INPUT_SIZE
,
NUM_CLASSES
,
bias
=
True
)
layer_master
=
layer_master
.
to
(
device
)
weight_master
=
layer_master
.
weight
.
data
...
...
@@ -219,7 +216,7 @@ def check_classifier_no_given_weight():
layer
.
bias
.
data
.
copy_
(
bias_master
)
A_shape
=
(
BATCH_SIZE
,
SEQ_LENGTH
,
INPUT_SIZE
)
A_master
=
torch
.
randn
(
A_shape
,
dtype
=
dtype
,
device
=
device
)
A_master
=
torch
.
randn
(
A_shape
,
device
=
device
)
torch
.
distributed
.
broadcast
(
A_master
,
src
=
0
)
A
=
torch
.
chunk
(
A_master
,
DEPTH
,
dim
=
0
)[
i
]
A
=
torch
.
chunk
(
A
,
DEPTH
,
dim
=-
1
)[
k
]
...
...
@@ -242,7 +239,7 @@ def check_classifier_no_given_weight():
logger
.
info
(
'Rank {} classifier (no given weight) forward: {}'
.
format
(
rank
,
check_equal
(
out
,
C
)))
grad_shape
=
C_master
.
shape
grad_master
=
torch
.
randn
(
grad_shape
,
dtype
=
dtype
,
device
=
get_current_device
())
grad_master
=
torch
.
randn
(
grad_shape
,
device
=
get_current_device
())
torch
.
distributed
.
broadcast
(
grad_master
,
src
=
0
)
grad
=
torch
.
chunk
(
grad_master
,
DEPTH
,
dim
=
0
)[
i
]
grad
=
torch
.
chunk
(
grad
,
DEPTH
,
dim
=
0
)[
j
]
...
...
@@ -283,7 +280,6 @@ def check_vocab_parallel_classifier_no_given_weight():
rank
=
torch
.
distributed
.
get_rank
()
logger
=
get_dist_logger
()
device
=
get_current_device
()
dtype
=
torch
.
float32
INPUT_SIZE
=
HIDDEN_SIZE
input_parallel_mode
=
get_parallel_mode_from_env
(
INPUT_GROUP_3D
)
...
...
@@ -295,10 +291,10 @@ def check_vocab_parallel_classifier_no_given_weight():
k
=
global_context
.
get_local_rank
(
output_parallel_mode
)
layer
=
VocabParallelClassifier3D
(
INPUT_SIZE
,
VOCAB_SIZE
,
bias
=
True
)
layer
=
layer
.
to
(
dtype
).
to
(
device
)
layer
=
layer
.
to
(
device
)
layer_master
=
VanillaClassifier
(
INPUT_SIZE
,
VOCAB_SIZE
,
bias
=
True
)
layer_master
=
layer_master
.
to
(
dtype
).
to
(
device
)
layer_master
=
layer_master
.
to
(
device
)
weight_master
=
layer_master
.
weight
.
data
torch
.
distributed
.
broadcast
(
weight_master
,
src
=
0
)
...
...
@@ -312,7 +308,7 @@ def check_vocab_parallel_classifier_no_given_weight():
layer
.
bias
.
data
.
copy_
(
bias
)
A_shape
=
(
BATCH_SIZE
,
SEQ_LENGTH
,
INPUT_SIZE
)
A_master
=
torch
.
randn
(
A_shape
,
dtype
=
dtype
,
device
=
device
)
A_master
=
torch
.
randn
(
A_shape
,
device
=
device
)
torch
.
distributed
.
broadcast
(
A_master
,
src
=
0
)
A
=
torch
.
chunk
(
A_master
,
DEPTH
,
dim
=
0
)[
i
]
A
=
torch
.
chunk
(
A
,
DEPTH
,
dim
=-
1
)[
k
]
...
...
@@ -336,7 +332,7 @@ def check_vocab_parallel_classifier_no_given_weight():
logger
.
info
(
'Rank {} vocab parallel classifier (no given weight) forward: {}'
.
format
(
rank
,
check_equal
(
out
,
C
)))
grad_shape
=
C_master
.
shape
grad_master
=
torch
.
randn
(
grad_shape
,
dtype
=
dtype
,
device
=
device
)
grad_master
=
torch
.
randn
(
grad_shape
,
device
=
device
)
torch
.
distributed
.
broadcast
(
grad_master
,
src
=
0
)
grad
=
torch
.
chunk
(
grad_master
,
DEPTH
,
dim
=
0
)[
i
]
grad
=
torch
.
chunk
(
grad
,
DEPTH
,
dim
=-
1
)[
j
]
...
...
@@ -455,7 +451,6 @@ def check_vocab_parallel_classifier_given_embed_weight():
rank
=
torch
.
distributed
.
get_rank
()
logger
=
get_dist_logger
()
device
=
get_current_device
()
dtype
=
torch
.
float32
input_parallel_mode
=
get_parallel_mode_from_env
(
INPUT_GROUP_3D
)
weight_parallel_mode
=
get_parallel_mode_from_env
(
WEIGHT_GROUP_3D
)
...
...
@@ -466,10 +461,10 @@ def check_vocab_parallel_classifier_given_embed_weight():
k
=
global_context
.
get_local_rank
(
output_parallel_mode
)
embed
=
VocabParallelEmbedding3D
(
VOCAB_SIZE
,
HIDDEN_SIZE
)
embed
=
embed
.
to
(
dtype
).
to
(
device
)
embed
=
embed
.
to
(
device
)
embed_master
=
torch
.
nn
.
Embedding
(
VOCAB_SIZE
,
HIDDEN_SIZE
)
embed_master
=
embed_master
.
to
(
dtype
).
to
(
device
)
embed_master
=
embed_master
.
to
(
device
)
weight_master
=
embed_master
.
weight
.
data
torch
.
distributed
.
broadcast
(
weight_master
,
src
=
0
)
...
...
@@ -479,10 +474,10 @@ def check_vocab_parallel_classifier_given_embed_weight():
embed
.
weight
.
data
.
copy_
(
weight
)
layer
=
VocabParallelClassifier3D
(
HIDDEN_SIZE
,
VOCAB_SIZE
,
weight
=
embed
.
weight
,
bias
=
False
)
layer
=
layer
.
to
(
dtype
).
to
(
device
)
layer
=
layer
.
to
(
device
)
layer_master
=
VanillaClassifier
(
HIDDEN_SIZE
,
VOCAB_SIZE
,
weight
=
embed_master
.
weight
,
bias
=
False
)
layer_master
=
layer_master
.
to
(
dtype
).
to
(
device
)
layer_master
=
layer_master
.
to
(
device
)
A_shape
=
(
BATCH_SIZE
,
SEQ_LENGTH
)
A_master
=
torch
.
randint
(
VOCAB_SIZE
,
A_shape
,
device
=
device
)
...
...
@@ -504,7 +499,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
logger
.
info
(
'Rank {} vocab parallel classifier (given embed weight) forward: {}'
.
format
(
rank
,
check_equal
(
out
,
C
)))
grad_shape
=
C_master
.
shape
grad_master
=
torch
.
randn
(
grad_shape
,
dtype
=
dtype
,
device
=
device
)
grad_master
=
torch
.
randn
(
grad_shape
,
device
=
device
)
torch
.
distributed
.
broadcast
(
grad_master
,
src
=
0
)
grad
=
torch
.
chunk
(
grad_master
,
DEPTH
,
dim
=
0
)[
i
]
grad
=
torch
.
chunk
(
grad
,
DEPTH
,
dim
=-
1
)[
j
]
...
...
@@ -546,12 +541,12 @@ def check_patch_embed():
i
=
global_context
.
get_local_rank
(
weight_parallel_mode
)
k
=
global_context
.
get_local_rank
(
output_parallel_mode
)
layer
=
PatchEmbedding3D
(
IMG_SIZE
,
4
,
3
,
HIDDEN_SIZE
,
dtype
=
dtype
)
layer
=
PatchEmbedding3D
(
IMG_SIZE
,
4
,
3
,
HIDDEN_SIZE
)
torch
.
nn
.
init
.
ones_
(
layer
.
cls_token
)
torch
.
nn
.
init
.
ones_
(
layer
.
pos_embed
)
layer
=
layer
.
to
(
device
)
layer_master
=
VanillaPatchEmbedding
(
IMG_SIZE
,
4
,
3
,
HIDDEN_SIZE
,
dtype
=
dtype
)
layer_master
=
VanillaPatchEmbedding
(
IMG_SIZE
,
4
,
3
,
HIDDEN_SIZE
)
torch
.
nn
.
init
.
ones_
(
layer_master
.
cls_token
)
torch
.
nn
.
init
.
ones_
(
layer_master
.
pos_embed
)
layer_master
=
layer_master
.
to
(
device
)
...
...
@@ -566,7 +561,7 @@ def check_patch_embed():
layer
.
bias
.
data
.
copy_
(
proj_bias
)
A_shape
=
(
BATCH_SIZE
,
3
,
IMG_SIZE
,
IMG_SIZE
)
A_master
=
torch
.
randn
(
A_shape
,
dtype
=
dtype
,
device
=
device
)
A_master
=
torch
.
randn
(
A_shape
,
device
=
device
)
torch
.
distributed
.
broadcast
(
A_master
,
src
=
0
)
A
=
A_master
.
clone
()
...
...
@@ -586,7 +581,7 @@ def check_patch_embed():
logger
.
info
(
'Rank {} patch embed forward: {}'
.
format
(
rank
,
check_equal
(
out
,
C
)))
grad_shape
=
C_master
.
shape
grad_master
=
torch
.
randn
(
grad_shape
,
dtype
=
dtype
,
device
=
device
)
grad_master
=
torch
.
randn
(
grad_shape
,
device
=
device
)
torch
.
distributed
.
broadcast
(
grad_master
,
src
=
0
)
grad
=
torch
.
chunk
(
grad_master
,
DEPTH
,
dim
=
0
)[
i
]
grad
=
torch
.
chunk
(
grad
,
DEPTH
,
dim
=-
1
)[
k
]
...
...
@@ -639,9 +634,9 @@ def check_embed():
k
=
global_context
.
get_local_rank
(
output_parallel_mode
)
layer
=
Embedding3D
(
VOCAB_SIZE
,
HIDDEN_SIZE
)
layer
=
layer
.
to
(
dtype
).
to
(
device
)
layer
=
layer
.
to
(
device
)
layer_master
=
torch
.
nn
.
Embedding
(
VOCAB_SIZE
,
HIDDEN_SIZE
)
layer_master
=
layer_master
.
to
(
dtype
).
to
(
device
)
layer_master
=
layer_master
.
to
(
device
)
weight_master
=
layer_master
.
weight
.
data
torch
.
distributed
.
broadcast
(
weight_master
,
src
=
0
)
...
...
@@ -669,7 +664,7 @@ def check_embed():
logger
.
info
(
'Rank {} embed forward: {}'
.
format
(
rank
,
check_equal
(
out
,
C
)))
grad_shape
=
C_master
.
shape
grad_master
=
torch
.
randn
(
grad_shape
,
dtype
=
dtype
,
device
=
device
)
grad_master
=
torch
.
randn
(
grad_shape
,
device
=
device
)
torch
.
distributed
.
broadcast
(
grad_master
,
src
=
0
)
grad
=
torch
.
chunk
(
grad_master
,
DEPTH
,
dim
=
0
)[
i
]
grad
=
torch
.
chunk
(
grad
,
DEPTH
,
dim
=-
1
)[
k
]
...
...
@@ -686,10 +681,7 @@ def check_embed():
B_grad
=
layer_master
.
weight
.
grad
B_grad
=
torch
.
chunk
(
B_grad
,
DEPTH
,
dim
=-
1
)[
k
]
if
j
==
k
:
logger
.
info
(
'Rank {} embed backward (weight_grad): {}'
.
format
(
rank
,
check_equal
(
B_grad
,
layer
.
weight
.
grad
)))
else
:
logger
.
info
(
'Rank {} embed backward (weight_grad): {}'
.
format
(
rank
,
layer
.
weight
.
grad
is
None
))
logger
.
info
(
'Rank {} embed backward (weight_grad): {}'
.
format
(
rank
,
check_equal
(
B_grad
,
layer
.
weight
.
grad
)))
return
fwd_end
-
fwd_start
,
bwd_end
-
bwd_start
...
...
@@ -709,9 +701,9 @@ def check_vocab_parallel_embed():
k
=
global_context
.
get_local_rank
(
output_parallel_mode
)
layer
=
VocabParallelEmbedding3D
(
VOCAB_SIZE
,
HIDDEN_SIZE
)
layer
=
layer
.
to
(
dtype
).
to
(
device
)
layer
=
layer
.
to
(
device
)
layer_master
=
torch
.
nn
.
Embedding
(
VOCAB_SIZE
,
HIDDEN_SIZE
)
layer_master
=
layer_master
.
to
(
dtype
).
to
(
device
)
layer_master
=
layer_master
.
to
(
device
)
weight_master
=
layer_master
.
weight
.
data
torch
.
distributed
.
broadcast
(
weight_master
,
src
=
0
)
...
...
@@ -741,7 +733,7 @@ def check_vocab_parallel_embed():
logger
.
info
(
'Rank {} vocab parallel embed forward: {}'
.
format
(
rank
,
check_equal
(
out
,
C
)))
grad_shape
=
C_master
.
shape
grad_master
=
torch
.
randn
(
grad_shape
,
dtype
=
dtype
,
device
=
device
)
grad_master
=
torch
.
randn
(
grad_shape
,
device
=
device
)
torch
.
distributed
.
broadcast
(
grad_master
,
src
=
0
)
grad
=
torch
.
chunk
(
grad_master
,
DEPTH
,
dim
=
0
)[
i
]
grad
=
torch
.
chunk
(
grad
,
DEPTH
,
dim
=-
1
)[
k
]
...
...
@@ -771,7 +763,6 @@ def check_loss():
rank
=
torch
.
distributed
.
get_rank
()
logger
=
get_dist_logger
()
device
=
get_current_device
()
dtype
=
torch
.
float32
input_parallel_mode
=
get_parallel_mode_from_env
(
INPUT_GROUP_3D
)
weight_parallel_mode
=
get_parallel_mode_from_env
(
WEIGHT_GROUP_3D
)
...
...
@@ -783,8 +774,8 @@ def check_loss():
criterion_master
=
torch
.
nn
.
CrossEntropyLoss
()
out_shape
=
(
BATCH_SIZE
,
NUM_CLASSES
)
out_master
=
torch
.
randn
(
out_shape
,
dtype
=
dtype
,
device
=
device
)
target_master
=
torch
.
randint
(
NUM_CLASSES
,
(
BATCH_SIZE
,),
dtype
=
torch
.
long
,
device
=
device
)
out_master
=
torch
.
randn
(
out_shape
,
device
=
device
)
target_master
=
torch
.
randint
(
NUM_CLASSES
,
(
BATCH_SIZE
,
),
dtype
=
torch
.
long
,
device
=
device
)
torch
.
distributed
.
broadcast
(
out_master
,
src
=
0
)
torch
.
distributed
.
broadcast
(
target_master
,
src
=
0
)
out
=
torch
.
chunk
(
out_master
,
DEPTH
,
dim
=
0
)[
i
]
...
...
@@ -836,8 +827,8 @@ def check_vocab_parallel_loss():
criterion_master
=
torch
.
nn
.
CrossEntropyLoss
()
out_shape
=
(
BATCH_SIZE
,
NUM_CLASSES
)
out_master
=
torch
.
randn
(
out_shape
,
dtype
=
dtype
,
device
=
device
)
target_master
=
torch
.
randint
(
NUM_CLASSES
,
(
BATCH_SIZE
,),
dtype
=
torch
.
long
,
device
=
device
)
out_master
=
torch
.
randn
(
out_shape
,
device
=
device
)
target_master
=
torch
.
randint
(
NUM_CLASSES
,
(
BATCH_SIZE
,
),
dtype
=
torch
.
long
,
device
=
device
)
torch
.
distributed
.
broadcast
(
out_master
,
src
=
0
)
torch
.
distributed
.
broadcast
(
target_master
,
src
=
0
)
out
=
torch
.
chunk
(
out_master
,
DEPTH
,
dim
=
0
)[
i
]
...
...
tests/test_layers/test_3d/checks_3d/common.py
View file @
0b8161fa
...
...
@@ -12,8 +12,8 @@ NUM_BLOCKS = 2
IMG_SIZE
=
16
VOCAB_SIZE
=
16
def
check_equal
(
A
,
B
):
eq
=
torch
.
allclose
(
A
,
B
,
rtol
=
1e-3
,
atol
=
1e-2
)
assert
eq
return
eq
assert
eq
,
f
"
\n
A =
{
A
}
\n
B =
{
B
}
"
return
eq
\ No newline at end of file
tests/test_layers/test_3d/test_3d.py
View file @
0b8161fa
...
...
@@ -10,9 +10,8 @@ from colossalai.initialize import launch
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.utils
import
free_port
from
colossalai.testing
import
rerun_if_address_is_in_use
,
skip_if_not_enough_gpus
from
checks_3d.check_layer_3d
import
(
check_classifier_given_embed_weight
,
check_classifier_no_given_weight
,
check_embed
,
check_layernorm
,
check_linear
,
check_loss
,
check_patch_embed
,
check_vocab_parallel_classifier_given_embed_weight
,
from
checks_3d.check_layer_3d
import
(
check_classifier_no_given_weight
,
check_embed
,
check_layernorm
,
check_linear
,
check_loss
,
check_patch_embed
,
check_vocab_parallel_classifier_given_embed_weight
,
check_vocab_parallel_classifier_no_given_weight
,
check_vocab_parallel_embed
,
check_vocab_parallel_loss
)
...
...
@@ -30,7 +29,6 @@ def check_layer():
check_layernorm
()
check_classifier_no_given_weight
()
check_vocab_parallel_classifier_no_given_weight
()
check_classifier_given_embed_weight
()
check_vocab_parallel_classifier_given_embed_weight
()
check_embed
()
check_patch_embed
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment