Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
404ecbdc
Commit
404ecbdc
authored
Oct 28, 2021
by
zbian
Browse files
Migrated project
parent
2ebaefc5
Changes
409
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3809 additions
and
0 deletions
+3809
-0
colossalai/nn/layer/base_layer.py
colossalai/nn/layer/base_layer.py
+27
-0
colossalai/nn/layer/parallel_1d/__init__.py
colossalai/nn/layer/parallel_1d/__init__.py
+5
-0
colossalai/nn/layer/parallel_1d/_utils.py
colossalai/nn/layer/parallel_1d/_utils.py
+15
-0
colossalai/nn/layer/parallel_1d/layers.py
colossalai/nn/layer/parallel_1d/layers.py
+166
-0
colossalai/nn/layer/parallel_2d/__init__.py
colossalai/nn/layer/parallel_2d/__init__.py
+11
-0
colossalai/nn/layer/parallel_2d/_operation.py
colossalai/nn/layer/parallel_2d/_operation.py
+522
-0
colossalai/nn/layer/parallel_2d/_transformer.py
colossalai/nn/layer/parallel_2d/_transformer.py
+220
-0
colossalai/nn/layer/parallel_2d/_utils.py
colossalai/nn/layer/parallel_2d/_utils.py
+23
-0
colossalai/nn/layer/parallel_2d/_vit.py
colossalai/nn/layer/parallel_2d/_vit.py
+391
-0
colossalai/nn/layer/parallel_2d/layers.py
colossalai/nn/layer/parallel_2d/layers.py
+258
-0
colossalai/nn/layer/parallel_2p5d/__init__.py
colossalai/nn/layer/parallel_2p5d/__init__.py
+13
-0
colossalai/nn/layer/parallel_2p5d/_operation.py
colossalai/nn/layer/parallel_2p5d/_operation.py
+535
-0
colossalai/nn/layer/parallel_2p5d/_transformer.py
colossalai/nn/layer/parallel_2p5d/_transformer.py
+206
-0
colossalai/nn/layer/parallel_2p5d/_utils.py
colossalai/nn/layer/parallel_2p5d/_utils.py
+25
-0
colossalai/nn/layer/parallel_2p5d/_vit.py
colossalai/nn/layer/parallel_2p5d/_vit.py
+351
-0
colossalai/nn/layer/parallel_2p5d/layers.py
colossalai/nn/layer/parallel_2p5d/layers.py
+266
-0
colossalai/nn/layer/parallel_3d/__init__.py
colossalai/nn/layer/parallel_3d/__init__.py
+9
-0
colossalai/nn/layer/parallel_3d/_operation.py
colossalai/nn/layer/parallel_3d/_operation.py
+349
-0
colossalai/nn/layer/parallel_3d/_utils.py
colossalai/nn/layer/parallel_3d/_utils.py
+49
-0
colossalai/nn/layer/parallel_3d/_vit.py
colossalai/nn/layer/parallel_3d/_vit.py
+368
-0
No files found.
colossalai/nn/layer/base_layer.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
torch.nn
as
nn
from
colossalai.context
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
class ParallelLayer(nn.Module):
    """Base class for parallel layers.

    On construction it snapshots the local rank and world size of the data,
    tensor and pipeline parallel groups from the global context, falling back
    to rank 0 / size 1 for any group that has not been initialized.
    """

    @staticmethod
    def _local_rank(mode):
        # Rank 0 is the safe default when the parallel group does not exist.
        return gpc.get_local_rank(mode) if gpc.is_initialized(mode) else 0

    @staticmethod
    def _world_size(mode):
        # A missing parallel group behaves like a group of size 1.
        return gpc.get_world_size(mode) if gpc.is_initialized(mode) else 1

    def __init__(self):
        super().__init__()
        self.data_parallel_rank = self._local_rank(ParallelMode.DATA)
        self.data_parallel_size = self._world_size(ParallelMode.DATA)
        self.tensor_parallel_rank = self._local_rank(ParallelMode.TENSOR)
        self.tensor_parallel_size = self._world_size(ParallelMode.TENSOR)
        self.pipeline_parallel_rank = self._local_rank(ParallelMode.PIPELINE)
        self.pipeline_parallel_size = self._world_size(ParallelMode.PIPELINE)
colossalai/nn/layer/parallel_1d/__init__.py
0 → 100644
View file @
404ecbdc
from
.layers
import
Linear1D_Col
,
Linear1D_Row
__all__
=
[
'Linear1D_Col'
,
'Linear1D_Row'
,
]
colossalai/nn/layer/parallel_1d/_utils.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from
.._common_utils
import
divide
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
    """Return the half-open ``[first, last)`` vocabulary index range owned by
    *rank*, given that every rank holds ``per_partition_vocab_size`` entries.
    """
    first = per_partition_vocab_size * rank
    return first, first + per_partition_vocab_size
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
    """Split ``global_vocab_size`` evenly over ``world_size`` ranks and return
    the index range owned by *rank* (``divide`` asserts even divisibility).
    """
    partition_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(partition_size, rank)
colossalai/nn/layer/parallel_1d/layers.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.init
as
init
from
torch
import
Tensor
from
torch.nn.parameter
import
Parameter
from
typing
import
Tuple
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.registry
import
LAYERS
from
colossalai.utils
import
get_current_device
from
.._common_utils
import
divide
from
.._parallel_utilities
import
reduce_grad
,
reduce_input
,
gather_forward_split_backward
,
\
split_forward_gather_backward
from
..base_layer
import
ParallelLayer
@LAYERS.register_module
class Linear1D_Col(ParallelLayer):
    """Linear layer with column parallelism.

    The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
    its second dimension as :math:`A = [A_1, ..., A_p]`.

    Registered in ``LAYERS`` for config-driven construction, consistent with
    its sibling :class:`Linear1D_Row`.

    :param in_features: first dimension of matrix A.
    :type in_features: int
    :param output_size: second dimension of matrix A.
    :type output_size: int
    :param bias: If true, add bias; if false, ``forward`` returns a
                 ``(output, bias)`` tuple so a fused kernel can add the bias
                 later (``bias`` is then None), defaults to True
    :type bias: bool, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param gather_output: If true, call all-gather on output and make Y available
                          to all GPUs, otherwise, every GPU will have its output
                          which is :math:`Y_i = XA_i`, defaults to False
    :type gather_output: bool, optional
    """

    def __init__(self,
                 in_features: int,
                 output_size: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
                 gather_output: bool = False):
        super().__init__()

        # Keep input parameters
        self.input_size = in_features
        self.output_size = output_size
        self.gather_output = gather_output
        # `bias=False` doubles as "skip bias add": forward then returns the
        # output together with self.bias (None) instead of adding it.
        self.skip_bias_add = not bias

        # Each rank of the 1D tensor-parallel group owns a column slice of A.
        world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
        self.output_size_per_partition = divide(output_size, world_size)

        # Parameters.
        # Initialize weight.
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
        self.weight = Parameter(torch.empty(
            self.output_size_per_partition, self.input_size,
            **factory_kwargs))

        if bias:
            self.bias = Parameter(torch.empty(
                self.output_size_per_partition,
                **factory_kwargs))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

    def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
        """Apply the column-parallel linear transform.

        :param input_: input tensor, replicated across the 1D parallel group
        :return: the (optionally gathered) output; when constructed with
                 ``bias=False``, a ``(output, bias)`` tuple instead
        """
        # Set up backprop all-reduce: identity in forward, grads are
        # all-reduced across the 1D group in backward.
        input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
        # Matrix multiply on the local column shard.
        bias = self.bias if not self.skip_bias_add else None
        output_parallel = F.linear(input_parallel, self.weight, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_forward_split_backward(
                output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
        else:
            output = output_parallel

        if self.skip_bias_add:
            return output, self.bias
        else:
            return output
@LAYERS.register_module
class Linear1D_Row(ParallelLayer):
    """ Linear layer with row parallelism

    The weight matrix is split along its input (row) dimension; each rank
    multiplies its input shard by its weight shard and the partial products
    are all-reduced across the 1D tensor-parallel group.

    :param in_features: size of each input sample
    :type in_features: int
    :param out_features: size of each output sample
    :type out_features: int
    :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
    :type bias: bool, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param parallel_input: If set to ``False``, it's assumed that the input is splitted, defaults to False
    :type parallel_input: bool, optional
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
                 parallel_input: bool = False):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.parallel_input = parallel_input
        self.skip_bias_add = not bias

        # Divide the weight matrix along the last dimension.
        num_partitions = gpc.get_world_size(ParallelMode.PARALLEL_1D)
        self.input_size_per_partition = divide(in_features, num_partitions)

        # Allocate the local shard on the current device.
        param_kwargs = {'device': get_current_device(), 'dtype': dtype}
        self.weight = Parameter(torch.empty(
            self.out_features, self.input_size_per_partition, **param_kwargs))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.empty(self.out_features, **param_kwargs))
            # Bias always starts at zero.
            with torch.no_grad():
                self.bias.zero_()

    def reset_parameters(self) -> None:
        """Re-initialize the local weight shard with Xavier-normal values."""
        init.xavier_normal_(self.weight)

    def forward(self, input_: Tensor) -> Tensor:
        # Split the input along the last dim unless the caller already did.
        if not self.parallel_input:
            x = split_forward_gather_backward(
                input_, ParallelMode.PARALLEL_1D, dim=-1)
        else:
            x = input_
        partial = F.linear(x, self.weight)
        # Sum partial products across the 1D group.
        out = reduce_input(partial, ParallelMode.PARALLEL_1D)
        if not self.skip_bias_add:
            out = out + self.bias
        return out
colossalai/nn/layer/parallel_2d/__init__.py
0 → 100644
View file @
404ecbdc
from
._operation
import
Matmul_AB_2D
,
Matmul_ABT_2D
,
Matmul_ATB_2D
,
Add_Bias_2D
,
matmul_2d
from
._transformer
import
TransformerMLP2D
,
TransformerSelfAttention2D
,
TransformerLayer2D
from
._vit
import
ViTMLP2D
,
ViTSelfAttention2D
,
ViTHead2D
,
ViTPatchEmbedding2D
,
ViTTokenFuser2D
,
ViTInputSplitter2D
from
.layers
import
Linear2D
,
LayerNorm2D
__all__
=
[
'Matmul_AB_2D'
,
'Matmul_ABT_2D'
,
'Matmul_ATB_2D'
,
'Add_Bias_2D'
,
'matmul_2d'
,
'TransformerMLP2D'
,
'TransformerSelfAttention2D'
,
'TransformerLayer2D'
,
'ViTMLP2D'
,
'ViTSelfAttention2D'
,
'ViTHead2D'
,
'ViTPatchEmbedding2D'
,
'ViTTokenFuser2D'
,
'ViTInputSplitter2D'
,
'Linear2D'
,
'LayerNorm2D'
]
colossalai/nn/layer/parallel_2d/_operation.py
0 → 100644
View file @
404ecbdc
from
typing
import
Any
,
Tuple
import
torch
import
torch.distributed
as
dist
from
torch
import
Tensor
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.utils
import
get_current_device
def matmul_2d(a,
              b,
              summa_dim,
              out_shape,
              row_rank=None,
              col_rank=None,
              row_parallel_mode=ParallelMode.PARALLEL_2D_ROW,
              col_parallel_mode=ParallelMode.PARALLEL_2D_COL,
              ):
    """Matrix multiplication for 2D parallelism

    :param a: matrix :math:`A`
    :type a: torch.tensor
    :param b: matrix :math:`B`
    :type b: torch.tensor
    :param summa_dim: dimension of SUMMA for 2D parallelism
    :type summa_dim: int
    :param out_shape: shape of output tensor
    :type out_shape: tuple
    :param row_rank: the rank of row, defaults to None
    :type row_rank: int, optional
    :param col_rank: the rank of column, defaults to None
    :type col_rank: int, optional
    :param row_parallel_mode: row parallel mode, defaults to ParallelMode.PARALLEL_2D_ROW
    :type row_parallel_mode: str, optional
    :param col_parallel_mode: column parallel mode, defaults to ParallelMode.PARALLEL_2D_COL
    :type col_parallel_mode: str, optional
    :return: :math:`C = AB`
    :rtype: torch.tensor
    """
    # Note: the row rank within the mesh is the local rank of the COLUMN
    # group and vice versa.
    if row_rank is None:
        row_rank = gpc.get_local_rank(col_parallel_mode)
    if col_rank is None:
        col_rank = gpc.get_local_rank(row_parallel_mode)

    # Fall back to rank 0 / size 1 for uninitialized parallel groups.
    data_parallel_rank = 0 if not gpc.is_initialized(
        ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
    pipeline_parallel_rank = 0 if not gpc.is_initialized(
        ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
    pipeline_parallel_size = 1 if not gpc.is_initialized(
        ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE)
    tensor_parallel_size = summa_dim ** 2
    # BUGFIX: invoke the autograd Function via .apply — calling the class
    # directly does not dispatch the staticmethod forward() and modern
    # PyTorch raises on the legacy callable-Function API.
    return Matmul_AB_2D.apply(
        a, b, summa_dim, out_shape, row_rank, col_rank,
        row_parallel_mode, col_parallel_mode,
        data_parallel_rank, pipeline_parallel_rank,
        pipeline_parallel_size, tensor_parallel_size)
class Matmul_AB_2D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = AB`

    SUMMA-style 2D-parallel matmul: each rank holds one tile of A and B on a
    ``summa_dim x summa_dim`` process mesh and accumulates the local tile of C
    over ``summa_dim`` broadcast rounds.
    """

    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                summa_dim: int,
                out_shape: Tuple[int, ...],
                row_rank: int,
                col_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int) -> Tensor:
        # A: [b / q, s, h / q] -> [(b * s) / q, h / q]
        # B: [h / q, s / q]
        # C: [b / q, s, s / q] -> [(b * s) / q, s / q]
        assert A.shape[-1] == B.shape[-2], \
            'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape)

        # ctx is None when called manually from a sibling backward(); skip
        # saving in that case.
        if ctx:
            ctx.save_for_backward(A, B)

        # Flatten leading dims so the loop below works on 2-D tiles.
        A_shape = A.shape
        A = A.reshape((-1, A_shape[-1]))
        B_shape = B.shape
        B = B.reshape((-1, B_shape[-1]))
        C_shape = (A.shape[0], B.shape[-1])
        C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())

        for i in range(summa_dim):
            # Work on copies: broadcast overwrites the buffer in place.
            A_temp = A.clone()
            B_temp = B.clone()
            # Global rank of the holder of A's i-th column tile in this row
            # group; the trailing terms offset into this pipeline/data replica.
            src_a = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.broadcast(A_temp, src=src_a,
                           group=gpc.get_group(row_parallel_mode))
            # Global rank of the holder of B's i-th row tile in this column group.
            src_b = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.broadcast(B_temp, src=src_b,
                           group=gpc.get_group(col_parallel_mode))
            # Accumulate: C = C + A_temp @ B_temp (out aliases the input C).
            torch.addmm(C, A_temp, B_temp, out=C)

        out = C.reshape(out_shape)

        if ctx:
            # Stash everything backward() needs to re-run the SUMMA loops.
            ctx.summa_dim = summa_dim
            ctx.row_rank = row_rank
            ctx.col_rank = col_rank
            ctx.row_parallel_mode = row_parallel_mode
            ctx.col_parallel_mode = col_parallel_mode
            ctx.A_shape = A_shape
            ctx.B_shape = B_shape
            ctx.data_parallel_rank = data_parallel_rank
            ctx.pipeline_parallel_rank = pipeline_parallel_rank
            ctx.pipeline_parallel_size = pipeline_parallel_size
            ctx.tensor_parallel_size = tensor_parallel_size

        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        A, B = ctx.saved_tensors
        # dA = dC @ B^T, dB = A^T @ dC — computed by the sibling Functions,
        # invoked with ctx=None so they do not record autograd state.
        A_grad = Matmul_ABT_2D.forward(
            None,
            output_grad, B,
            ctx.summa_dim, ctx.A_shape,
            ctx.row_rank, ctx.col_rank,
            ctx.row_parallel_mode,
            ctx.col_parallel_mode,
            ctx.data_parallel_rank,
            ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size,
            ctx.tensor_parallel_size)
        B_grad = Matmul_ATB_2D.forward(
            None,
            A, output_grad,
            ctx.summa_dim, ctx.B_shape,
            ctx.row_rank, ctx.col_rank,
            ctx.row_parallel_mode,
            ctx.col_parallel_mode,
            ctx.data_parallel_rank,
            ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size,
            ctx.tensor_parallel_size)
        # One grad slot per forward argument; only A and B are differentiable.
        return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
class Matmul_ABT_2D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = AB^T`

    SUMMA variant: B tiles are broadcast down the column group, local partial
    products are reduced along the row group, and each rank keeps the round
    whose reduce destination is itself.
    """

    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                summa_dim: int,
                out_shape: Tuple[int, ...],
                row_rank: int,
                col_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int) -> Tensor:
        assert A.shape[-1] == B.shape[-1], \
            'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape)

        # ctx is None when invoked manually from a sibling backward().
        if ctx:
            ctx.save_for_backward(A, B)

        # Flatten leading dims to plain 2-D tiles.
        A_shape = A.shape
        A = A.reshape((-1, A_shape[-1]))
        B_shape = B.shape
        B = B.reshape((-1, B_shape[-1]))
        C_shape = (A.shape[0], B.shape[0])
        C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())

        for i in range(summa_dim):
            B_temp = B.clone()
            # C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device())
            # Holder of B's i-th tile within this column group (offset into
            # this data/pipeline replica).
            src_b = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.broadcast(B_temp, src=src_b,
                           group=gpc.get_group(col_parallel_mode))
            C_temp = torch.matmul(A, B_temp.transpose(0, 1))
            # Reduce the partial product to the rank that owns column i of C.
            src_c = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(C_temp, dst=src_c,
                        group=gpc.get_group(row_parallel_mode))
            # Only the destination round is this rank's own tile of C.
            if i == col_rank:
                C = C_temp.clone()

        out = C.reshape(out_shape)

        if ctx:
            # Stash loop parameters for backward().
            ctx.summa_dim = summa_dim
            ctx.row_rank = row_rank
            ctx.col_rank = col_rank
            ctx.row_parallel_mode = row_parallel_mode
            ctx.col_parallel_mode = col_parallel_mode
            ctx.A_shape = A_shape
            ctx.B_shape = B_shape
            ctx.data_parallel_rank = data_parallel_rank
            ctx.pipeline_parallel_rank = pipeline_parallel_rank
            ctx.pipeline_parallel_size = pipeline_parallel_size
            ctx.tensor_parallel_size = tensor_parallel_size

        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        A, B = ctx.saved_tensors
        # For C = A B^T: dA = dC @ B, dB = dC^T @ A.
        A_grad = Matmul_AB_2D.forward(
            None,
            output_grad, B,
            ctx.summa_dim, ctx.A_shape,
            ctx.row_rank, ctx.col_rank,
            ctx.row_parallel_mode,
            ctx.col_parallel_mode,
            ctx.data_parallel_rank,
            ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size,
            ctx.tensor_parallel_size)
        B_grad = Matmul_ATB_2D.forward(
            None,
            output_grad, A,
            ctx.summa_dim, ctx.B_shape,
            ctx.row_rank, ctx.col_rank,
            ctx.row_parallel_mode,
            ctx.col_parallel_mode,
            ctx.data_parallel_rank,
            ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size,
            ctx.tensor_parallel_size)
        # One grad slot per forward argument; only A and B are differentiable.
        return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
class Matmul_ATB_2D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = A^TB`

    SUMMA variant: A tiles are broadcast along the row group, local partial
    products are reduced down the column group, and each rank keeps the round
    whose reduce destination is itself.
    """

    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                summa_dim: int,
                out_shape: Tuple[int, ...],
                row_rank: int,
                col_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int) -> Tensor:
        assert A.shape[-2] == B.shape[-2], \
            'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape)

        # ctx is None when invoked manually from a sibling backward().
        if ctx:
            ctx.save_for_backward(A, B)

        # Flatten leading dims to plain 2-D tiles.
        A_shape = A.shape
        A = A.reshape((-1, A_shape[-1]))
        B_shape = B.shape
        B = B.reshape((-1, B_shape[-1]))
        C_shape = (A.shape[-1], B.shape[-1])
        C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())

        for i in range(summa_dim):
            A_temp = A.clone()
            # C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device())
            # Holder of A's i-th tile within this row group (offset into this
            # data/pipeline replica).
            src_a = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.broadcast(A_temp, src=src_a,
                           group=gpc.get_group(row_parallel_mode))
            C_temp = torch.matmul(A_temp.transpose(0, 1), B)
            # Reduce the partial product to the rank that owns row i of C.
            src_c = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(C_temp, dst=src_c,
                        group=gpc.get_group(col_parallel_mode))
            # Only the destination round is this rank's own tile of C.
            if i == row_rank:
                C = C_temp.clone()

        out = C.reshape(out_shape)

        if ctx:
            # Stash loop parameters for backward().
            ctx.summa_dim = summa_dim
            ctx.row_rank = row_rank
            ctx.col_rank = col_rank
            ctx.row_parallel_mode = row_parallel_mode
            ctx.col_parallel_mode = col_parallel_mode
            ctx.A_shape = A_shape
            ctx.B_shape = B_shape
            ctx.data_parallel_rank = data_parallel_rank
            ctx.pipeline_parallel_rank = pipeline_parallel_rank
            ctx.pipeline_parallel_size = pipeline_parallel_size
            ctx.tensor_parallel_size = tensor_parallel_size

        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        A, B = ctx.saved_tensors
        # For C = A^T B: dA = B @ dC^T, dB = A @ dC.
        A_grad = Matmul_ABT_2D.forward(
            None,
            B, output_grad,
            ctx.summa_dim, ctx.A_shape,
            ctx.row_rank, ctx.col_rank,
            ctx.row_parallel_mode,
            ctx.col_parallel_mode,
            ctx.data_parallel_rank,
            ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size,
            ctx.tensor_parallel_size)
        B_grad = Matmul_AB_2D.forward(
            None,
            A, output_grad,
            ctx.summa_dim, ctx.B_shape,
            ctx.row_rank, ctx.col_rank,
            ctx.row_parallel_mode,
            ctx.col_parallel_mode,
            ctx.data_parallel_rank,
            ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size,
            ctx.tensor_parallel_size)
        # One grad slot per forward argument; only A and B are differentiable.
        return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
class Add_Bias_2D(torch.autograd.Function):
    """Matrix add bias: :math:`C = A + b`

    The bias shard lives on the first mesh row (row_rank == 0) and is
    broadcast down each column group before being added (or returned alone
    when ``skip_bias_add`` is set, for later kernel fusion).
    """

    @staticmethod
    def forward(ctx: Any,
                input: Tensor,
                bias: Tensor,
                output_size_per_partition: int,
                row_rank: int,
                col_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                skip_bias_add: bool,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int) -> Tensor:
        # Only the first mesh row owns the real bias values; other rows
        # allocate a zero buffer to receive the broadcast.
        if row_rank == 0:
            bias_temp = bias.clone()
        else:
            bias_temp = torch.zeros(
                output_size_per_partition,
                dtype=bias.dtype,
                device=get_current_device())
        # Global rank of this column's row-0 member (offset into this
        # data/pipeline replica).
        src_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
            pipeline_parallel_rank * tensor_parallel_size
        dist.broadcast(bias_temp, src=src_rank,
                       group=gpc.get_group(col_parallel_mode))

        ctx.row_rank = row_rank
        ctx.col_rank = col_rank
        ctx.row_parallel_mode = row_parallel_mode
        ctx.col_parallel_mode = col_parallel_mode
        # NOTE: stored under the name `bias` but holds the skip flag.
        ctx.bias = skip_bias_add
        ctx.data_parallel_rank = data_parallel_rank
        ctx.pipeline_parallel_rank = pipeline_parallel_rank
        ctx.pipeline_parallel_size = pipeline_parallel_size
        ctx.tensor_parallel_size = tensor_parallel_size

        if skip_bias_add:
            # Caller fuses the addition itself; hand back the broadcast bias.
            return bias_temp
        else:
            output = input + bias_temp
            return output

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        row_rank = ctx.row_rank
        col_rank = ctx.col_rank
        row_parallel_mode = ctx.row_parallel_mode
        col_parallel_mode = ctx.col_parallel_mode
        data_parallel_rank = ctx.data_parallel_rank
        pipeline_parallel_rank = ctx.pipeline_parallel_rank
        pipeline_parallel_size = ctx.pipeline_parallel_size
        tensor_parallel_size = ctx.tensor_parallel_size

        if ctx.bias:
            # skip_bias_add path: forward returned the bias only, so the
            # incoming grad IS the bias grad; reduce it to the bias owner.
            dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(output_grad, dst=dst_rank,
                        group=gpc.get_group(col_parallel_mode))
            if row_rank == 0:
                return None, output_grad, None, None, None, None, None, None, None, None, None, None
            else:
                # for compatibility with zero optimizer, no grad should be None
                grad_tmp = torch.zeros_like(output_grad)
                return None, grad_tmp, None, None, None, None, None, None, None, None, None, None
        else:
            # Sum the grad over all leading dims to get the per-feature bias grad.
            reduce_dim = tuple(range(output_grad.ndim - 1))
            reduce = torch.sum(output_grad, dim=reduce_dim)
            dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(reduce, dst=dst_rank,
                        group=gpc.get_group(col_parallel_mode))
            if row_rank == 0:
                return output_grad, reduce, None, None, None, None, None, None, None, None, None, None
            else:
                # for compatibility with zero optimizer, no grad should be None
                reduce_tmp = torch.zeros_like(reduce)
                return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None
class _LayerNorm_2D(torch.autograd.Function):
    """Normalization step of 2D-parallel LayerNorm: given precomputed mean
    ``E_x`` and inverse std ``Var_x``, returns ``(input - E_x) * Var_x`` with a
    hand-derived backward that all-reduces the row-wise sums across the row
    parallel group (the hidden dim is sharded along it).
    """

    @staticmethod
    def forward(ctx: Any,
                input: Tensor,
                E_x: Tensor,
                Var_x: Tensor,
                hidden_size: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode) -> Tensor:
        input = input - E_x
        # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
        ctx.normalized_shape = hidden_size
        output = input * Var_x
        # Save the normalized output (not the raw input) — backward is written
        # in terms of x_hat.
        ctx.save_for_backward(output, Var_x)
        ctx.row_parallel_mode = row_parallel_mode
        ctx.col_parallel_mode = col_parallel_mode
        return output

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        row_parallel_mode = ctx.row_parallel_mode
        col_parallel_mode = ctx.col_parallel_mode
        x, Var_x = ctx.saved_tensors
        # in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x
        # Mean of dL/dy over the full (sharded) hidden dim.
        output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True)
        torch.distributed.all_reduce(
            output_grad_sum, group=gpc.get_group(row_parallel_mode))
        output_grad_sum /= ctx.normalized_shape

        # Mean of dL/dy * x_hat over the full hidden dim.
        output_grad_mul_x_sum = torch.sum(
            output_grad * x, dim=-1, keepdim=True)
        torch.distributed.all_reduce(
            output_grad_mul_x_sum, group=gpc.get_group(row_parallel_mode))
        output_grad_mul_x_sum /= ctx.normalized_shape

        # dL/dx = (dL/dy - mean(dL/dy) - x_hat * mean(dL/dy * x_hat)) / std
        input_grad = output_grad.clone()
        input_grad -= x * output_grad_mul_x_sum
        input_grad -= output_grad_sum
        input_grad *= Var_x

        return input_grad, None, None, None, None, None
# class Sum_2D(torch.autograd.Function):
#
# @staticmethod
# def forward(ctx: Any,
# inputs: Tensor,
# dim: int,
# summa_dim: int,
# row_parallel_mode: ParallelMode,
# keepdim: bool = False) -> Tensor:
# # input: [b/q, s, h/q]
# empty_cache()
# ctx.save_for_backward(inputs)
# # sum: [b/q, s]
# out = torch.sum(inputs, dim=dim, keepdim=keepdim)
# torch.distributed.all_reduce(out, group=gpc.get_group(row_parallel_mode))
# return out
#
# @staticmethod
# def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# with torch.no_grad():
# inputs = ctx.saved_tensors
# input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype)
# return input_grad, None, None, None, None, None
class _ViT_Split_Input_2D(torch.autograd.Function):
    """Splits the ViT input batch across the column parallel group: each rank
    keeps its ``1/summa_dim`` slice of the batch dim in forward and the slices
    are all-gathered back in backward.
    """

    @staticmethod
    def forward(ctx: Any,
                inputs: Tensor,
                batch_size: int,
                summa_dim: int,
                col_parallel_mode: ParallelMode) -> Tensor:
        # inputs: [b, s, h/q]
        # output: [b/q, s, h/q]
        ctx.BATCH_SIZE = batch_size
        ctx.summa_dim = summa_dim
        ctx.col_parallel_mode = col_parallel_mode
        # Each rank keeps the batch chunk indexed by its row position.
        row_rank = gpc.get_local_rank(col_parallel_mode)
        output = torch.chunk(inputs, summa_dim, dim=0)[row_rank]
        # Clone so the result does not alias the (possibly freed) input view.
        output = output.clone()
        return output

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        # output_grad: [b/q, s, h/q]
        # grads: [b, s, h/q]
        grads_shape = (ctx.BATCH_SIZE,) + output_grad.shape[1:]
        grads = torch.empty(grads_shape,
                            dtype=output_grad.dtype,
                            device=get_current_device())
        # Gather directly into chunk views of the preallocated buffer so the
        # full-batch grad is assembled in place.
        dist.all_gather(list(grads.chunk(ctx.summa_dim, dim=0)),
                        output_grad.contiguous(),
                        group=gpc.get_group(ctx.col_parallel_mode))
        return grads, None, None, None
colossalai/nn/layer/parallel_2d/_transformer.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
math
import
torch
from
torch
import
nn
as
nn
,
Tensor
from
colossalai.nn.layer._common_utils
import
divide
,
ACT2FN
from
colossalai.nn.layer.parallel_2d._utils
import
assert_summa_initialization
,
get_summa_dim_from_env
from
colossalai.registry
import
LAYERS
from
.layers
import
Linear2D
,
LayerNorm2D
from
..base_layer
import
ParallelLayer
@LAYERS.register_module
class TransformerMLP2D(ParallelLayer):
    """2D-parallel transformer MLP block.

    Projects the hidden state to ``mlp_ratio * h``, applies a nonlinearity,
    projects back to ``h``, then applies dropout and a residual LayerNorm.

    :param in_features: the size of input tensor
    :type in_features: int
    :param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
    :type mlp_ratio: int, optional
    :param act_func: activation function, defaults to 'gelu'
    :type act_func: str, optional
    :param dropout_prob: dropout probability, defaults to 0.
    :type dropout_prob: float, optional
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False
    :type skip_bias_add: bool, optional
    """

    def __init__(self,
                 in_features: int,
                 mlp_ratio: int = 4.0,
                 act_func: str = 'gelu',
                 dropout_prob: float = 0.,
                 dtype=None,
                 skip_bias_add: bool = False):
        super().__init__()
        assert_summa_initialization()
        self.summa_dim = get_summa_dim_from_env()
        self.in_features = in_features
        self.skip_bias_add = skip_bias_add

        hidden_features = int(mlp_ratio * in_features)
        # Expansion projection: h -> mlp_ratio * h.
        self.dense_1 = Linear2D(in_features,
                                hidden_features,
                                dtype=dtype,
                                skip_bias_add=self.skip_bias_add)
        assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
            f'activation function can only be {list(ACT2FN.keys())}'
        self.activation_func = ACT2FN[act_func]
        # Contraction projection: mlp_ratio * h -> h.
        self.dense_2 = Linear2D(hidden_features,
                                in_features,
                                dtype=dtype,
                                skip_bias_add=self.skip_bias_add)
        self.dropout = nn.Dropout(dropout_prob)
        self.layernorm = LayerNorm2D(in_features, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        # With skip_bias_add, Linear2D returns (output, bias); discard the bias.
        hidden = self.dense_1(x)
        if self.skip_bias_add:
            hidden, _ = hidden
        hidden = self.activation_func(hidden)

        out = self.dense_2(hidden)
        if self.skip_bias_add:
            out, _ = out
        out = self.dropout(out)
        # Post-norm residual connection.
        return self.layernorm(x + out)
@LAYERS.register_module
class TransformerSelfAttention2D(ParallelLayer):
    """Self attention layer for 2D parallel Transformer

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_attention_heads: number of attention heads
    :type num_attention_heads: int
    :param attention_dropout_prob: dropout probability for attention layer
    :type attention_dropout_prob: float
    :param hidden_dropout_prob: dropout probability for hidden layer
    :type hidden_dropout_prob: float
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def __init__(self,
                 hidden_size: int,
                 num_attention_heads: int,
                 attention_dropout_prob: float,
                 hidden_dropout_prob: float,
                 dtype=None,
                 ):
        super().__init__()
        assert_summa_initialization()
        self.summa_dim = get_summa_dim_from_env()
        self.hidden_size = hidden_size
        # Heads are sharded across the mesh dimension; head size uses the
        # GLOBAL head count so per-head width is unchanged.
        self.num_attention_heads = divide(num_attention_heads, self.summa_dim)
        self.attention_head_size = divide(hidden_size, num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Fused Q/K/V projection.
        self.query_key_value = Linear2D(hidden_size, 3 * hidden_size, dtype=dtype)
        self.attention_dropout = nn.Dropout(attention_dropout_prob)
        self.dense = Linear2D(hidden_size, hidden_size, dtype=dtype)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.layernorm = LayerNorm2D(hidden_size, dtype=dtype)

    def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        # [b, s, 3h'] -> [b, heads, s, 3 * head_size]
        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(qkv.shape[:-1] + (self.num_attention_heads,
                                         3 * self.attention_head_size))
        qkv = qkv.permute((0, 2, 1, 3))
        q, k, v = torch.chunk(qkv, 3, dim=-1)

        # Scaled dot-product scores with additive mask.
        scores = torch.matmul(q, k.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)
        scores = scores + attention_mask
        probs = torch.softmax(scores, dim=-1)
        probs = self.attention_dropout(probs)

        # Weighted values back to [b, s, all_head_size].
        context = torch.matmul(probs, v)
        context = context.permute((0, 2, 1, 3)).contiguous()
        context = context.view(*(context.size()[:-2] + (self.all_head_size,)))

        out = self.dense(context)
        out = self.dropout(out)
        # Post-norm residual connection.
        return self.layernorm(hidden_states + out)
@LAYERS.register_module
class TransformerLayer2D(ParallelLayer):
    """A single Transformer block (self-attention followed by an MLP) for 2D
    tensor parallelism.

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_attention_heads: number of attention heads
    :type num_attention_heads: int
    :param act_func: activation function used by the MLP, defaults to 'gelu'
    :type act_func: str, optional
    :param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
    :type mlp_ratio: float, optional
    :param attention_dropout_prob: dropout probability for the attention layer, defaults to 0.
    :type attention_dropout_prob: float, optional
    :param hidden_dropout_prob: dropout probability for hidden activations, defaults to 0.
    :type hidden_dropout_prob: float, optional
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def __init__(self,
                 hidden_size: int,
                 num_attention_heads: int,
                 act_func: str = 'gelu',
                 mlp_ratio: float = 4.0,
                 attention_dropout_prob: float = 0.,
                 hidden_dropout_prob: float = 0.,
                 dtype=None,
                 ):
        super().__init__()
        # NOTE: attribute names `attention` / `mlp` are part of the state-dict
        # layout and must not be renamed.
        self.attention = TransformerSelfAttention2D(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            attention_dropout_prob=attention_dropout_prob,
            hidden_dropout_prob=hidden_dropout_prob,
            dtype=dtype,
        )
        self.mlp = TransformerMLP2D(
            in_features=hidden_size,
            dropout_prob=hidden_dropout_prob,
            act_func=act_func,
            mlp_ratio=mlp_ratio,
            dtype=dtype,
        )

    def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        """Apply self-attention, then feed the result through the MLP."""
        return self.mlp(self.attention(hidden_states, attention_mask))
colossalai/nn/layer/parallel_2d/_utils.py
0 → 100644
View file @
404ecbdc
import
os
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.context.process_group_initializer.initializer_2d
import
SUMMA_DIM
from
colossalai.core
import
global_context
as
gpc
def get_summa_dim_from_env() -> int:
    """Return the SUMMA dimension stored in the ``SUMMA_DIM`` environment variable.

    :return: the SUMMA dimension (a positive integer)
    :raises EnvironmentError: if the variable is missing, non-numeric,
        or not larger than zero
    """
    try:
        summa_dim = int(os.environ[SUMMA_DIM])
    except KeyError as e:
        # Chain the KeyError so the root cause survives in the traceback.
        raise EnvironmentError(
            'SUMMA_DIM is not found in the current environment, '
            'please make sure that you have used the correct process group initializer'
        ) from e
    except ValueError as e:
        # Environment values are strings; a non-numeric value would otherwise
        # escape as a bare ValueError with no context.
        raise EnvironmentError(
            'SUMMA_DIM must be an integer') from e
    # `assert` is stripped under `python -O`; validate explicitly instead.
    if summa_dim <= 0:
        raise EnvironmentError('SUMMA_DIM must be larger than zero')
    return summa_dim
def assert_summa_initialization():
    """Assert that both 2D process groups (row and column) have been set up."""
    groups_ready = gpc.is_initialized(ParallelMode.PARALLEL_2D_COL) and \
        gpc.is_initialized(ParallelMode.PARALLEL_2D_ROW)
    assert groups_ready, \
        'Both TWO_DIMENSION_COL and TWO_DIMENSION_ROW must be initialized by the process group initializer'
colossalai/nn/layer/parallel_2d/_vit.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
math
import
torch
from
torch
import
nn
as
nn
,
Tensor
,
distributed
as
dist
from
colossalai.context
import
seed
,
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.nn.layer._common_utils
import
divide
,
ACT2FN
from
colossalai.nn.layer.parallel_2d._utils
import
assert_summa_initialization
,
get_summa_dim_from_env
from
colossalai.nn.layer.vanilla_vision_transformer.layers
import
to_2tuple
from
colossalai.registry
import
LAYERS
from
colossalai.utils
import
checkpoint
from
colossalai.utils
import
get_current_device
from
._operation
import
_ViT_Split_Input_2D
from
.layers
import
Linear2D
from
.._common_utils
import
set_tensor_parallel_attribute
from
..base_layer
import
ParallelLayer
@LAYERS.register_module
class ViTMLP2D(ParallelLayer):
    """Two-layer feed-forward (MLP) block of a 2D-parallel Vision Transformer.

    :param in_features: size of each input sample
    :type in_features: int
    :param mlp_ratio: hidden size of MLP divided by embedding dim
    :type mlp_ratio: int
    :param act_func: activation function, defaults to 'gelu'
    :type act_func: str, optional
    :param dropout_prob: dropout probability, defaults to 0.
    :type dropout_prob: float, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param checkpoint: whether to checkpoint the layer, defaults to False
    :type checkpoint: bool, optional
    """

    def __init__(self,
                 in_features: int,
                 mlp_ratio: int,
                 act_func: str = 'gelu',
                 dropout_prob: float = 0.,
                 dtype=None,
                 checkpoint: bool = False
                 ):
        super().__init__()
        assert_summa_initialization()
        self.summa_dim = get_summa_dim_from_env()
        self.in_features = in_features
        self.mlp_ratio = mlp_ratio
        self.checkpoint = checkpoint

        hidden_features = self.mlp_ratio * self.in_features
        # Expansion projection: h -> mlp_ratio * h.
        self.dense_1 = Linear2D(
            self.in_features,
            hidden_features,
            dtype=dtype,
        )
        self.act = ACT2FN[act_func]
        # Contraction projection: mlp_ratio * h -> h.
        self.dense_2 = Linear2D(
            hidden_features,
            self.in_features,
            dtype=dtype,
        )
        self.dropout = nn.Dropout(dropout_prob)

    def _forward(self, hidden_states: Tensor) -> Tensor:
        """dense_1 -> activation -> dropout -> dense_2 -> dropout."""
        out = self.dense_1(hidden_states)
        out = self.act(out)
        # Dropout masks are drawn from the tensor-parallel RNG state.
        with seed(ParallelMode.TENSOR):
            out = self.dropout(out)
        out = self.dense_2(out)
        with seed(ParallelMode.TENSOR):
            out = self.dropout(out)
        return out

    def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
        # Recompute activations during backward to save memory.
        return checkpoint(self._forward, hidden_states)

    def forward(self, hidden_states: Tensor) -> Tensor:
        run = self._checkpoint_forward if self.checkpoint else self._forward
        return run(hidden_states)
@LAYERS.register_module
class ViTSelfAttention2D(ParallelLayer):
    """Self-attention layer for 2D parallel Vision Transformer.

    Unlike the Transformer variant in this package, this layer takes no
    attention mask and returns the projected attention output WITHOUT a
    residual connection or layer norm.

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_attention_heads: number of attention heads
    :type num_attention_heads: int
    :param attention_dropout_prob: dropout probability for attention layers
    :type attention_dropout_prob: float
    :param hidden_dropout_prob: dropout probability for hidden layers
    :type hidden_dropout_prob: float
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param checkpoint: whether to checkpoint the layer, defaults to False
    :type checkpoint: bool, optional
    """

    def __init__(self,
                 hidden_size: int,
                 num_attention_heads: int,
                 attention_dropout_prob: float,
                 hidden_dropout_prob: float,
                 dtype=None,
                 checkpoint: bool = False
                 ):
        super().__init__()
        assert_summa_initialization()
        self.summa_dim = get_summa_dim_from_env()
        self.hidden_size = hidden_size
        # Heads are partitioned across the SUMMA dimension, so each rank
        # holds num_attention_heads / summa_dim heads.
        self.num_attention_heads = divide(num_attention_heads, self.summa_dim)
        # Head size is computed from the GLOBAL head count, not the local one.
        self.attention_head_size = divide(hidden_size, num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.checkpoint = checkpoint

        # Fused projection producing Q, K and V in one matmul.
        self.query_key_value = Linear2D(
            hidden_size,
            3 * hidden_size,
            dtype=dtype,
        )
        self.attention_dropout = nn.Dropout(attention_dropout_prob)
        self.dense = Linear2D(
            hidden_size,
            hidden_size,
            dtype=dtype,
        )
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.softmax = nn.Softmax(dim=-1)

    def _forward(self, hidden_states: Tensor) -> Tensor:
        """Compute scaled dot-product self-attention over ``hidden_states``."""
        query_key_value = self.query_key_value(hidden_states)
        # (..., 3h/q) -> (..., heads, 3*head_size)
        new_qkv_shape = query_key_value.shape[:-1] + \
            (self.num_attention_heads, 3 * self.attention_head_size)
        query_key_value = query_key_value.view(new_qkv_shape)
        # -> (batch, heads, seq, 3*head_size)
        query_key_value = query_key_value.permute((0, 2, 1, 3))
        query_layer, key_layer, value_layer = torch.chunk(
            query_key_value, 3, dim=-1)
        attention_scores = torch.matmul(
            query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / \
            math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        # Dropout mask comes from the tensor-parallel RNG state.
        with seed(ParallelMode.TENSOR):
            attention_probs = self.attention_dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)
        # (batch, heads, seq, head_size) -> (batch, seq, heads, head_size)
        context_layer = context_layer.transpose(1, 2)
        # Merge heads back into a single hidden dimension; `reshape` copies
        # if needed since `transpose` leaves the tensor non-contiguous.
        new_context_layer_shape = context_layer.size()[
            :-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)
        output = self.dense(context_layer)
        with seed(ParallelMode.TENSOR):
            output = self.dropout(output)
        return output

    def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
        # Trade compute for memory: recompute _forward during backward.
        return checkpoint(self._forward, hidden_states)

    def forward(self, hidden_states: Tensor) -> Tensor:
        if self.checkpoint:
            return self._checkpoint_forward(hidden_states)
        else:
            return self._forward(hidden_states)
@LAYERS.register_module
class ViTHead2D(ParallelLayer):
    """Classification head for the 2D-parallel Vision Transformer.

    Selects the first token of each sequence and projects it to class logits.

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_classes: number of classes
    :type num_classes: int
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def __init__(self,
                 hidden_size,
                 num_classes,
                 dtype=None,
                 ):
        super().__init__()
        assert_summa_initialization()
        self.summa_dim = get_summa_dim_from_env()
        self.linear = Linear2D(
            hidden_size,
            num_classes,
            dtype=dtype,
        )

    def forward(self, x: Tensor) -> Tensor:
        # x[:, 0] keeps only the leading token of every sequence.
        return self.linear(x[:, 0])
@LAYERS.register_module
class ViTPatchEmbedding2D(ParallelLayer):
    """2D Image to Patch Embedding.

    The embedding dimension is partitioned across the SUMMA dimension; the
    convolution weights are synchronized across the column process group
    (broadcast at init, gradient all-reduce during backward).

    :param img_size: image size
    :type img_size: int
    :param patch_size: patch size
    :type patch_size: int
    :param embed_dim: dimension of embedding
    :type embed_dim: int
    :param in_chans: number of channels of input image, defaults to 3
    :type in_chans: int, optional
    :param flatten: whether to flatten output tensor, defaults to True
    :type flatten: bool, optional
    """

    def __init__(self,
                 img_size,
                 patch_size,
                 embed_dim,
                 in_chans=3,
                 flatten=True):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        assert_summa_initialization()
        self.summa_dim = get_summa_dim_from_env()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        # Each rank holds a 1/summa_dim slice of the embedding dimension.
        self.embed_dim = embed_dim // self.summa_dim

        with seed(ParallelMode.TENSOR):
            # ensure the partitions are initialized differently
            self.proj = nn.Conv2d(in_chans,
                                  self.embed_dim,
                                  kernel_size=patch_size,
                                  stride=patch_size
                                  )

        # sync params at init and their grads during backward
        self._broadcast_conv_params()
        self.proj.weight.register_hook(self._sync_grad_during_backward)
        self.proj.bias.register_hook(self._sync_grad_during_backward)
        # FIX: this call was missing; sibling layers (ViTTokenFuser2D,
        # Linear2D) tag their parameters in __init__ as well.
        self._set_tensor_parallel_attribute()

    def _set_tensor_parallel_attribute(self):
        # Mark parameters as tensor-parallel so the rest of the framework
        # treats them accordingly.
        set_tensor_parallel_attribute(self.proj.weight)
        set_tensor_parallel_attribute(self.proj.bias)

    def _broadcast_conv_params(self) -> None:
        """Copy rank 0's conv parameters to every rank in the column group."""
        self.to(get_current_device())
        ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)
        dist.broadcast(self.proj.weight, src=ranks_in_col[0],
                       group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
        dist.broadcast(self.proj.bias, src=ranks_in_col[0],
                       group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))

    def _sync_grad_during_backward(self, grad: Tensor) -> Tensor:
        """All-reduce the gradient over the column group and average it.

        FIX: the return annotation was ``-> None`` although the hook returns
        the rescaled gradient (PyTorch uses a hook's return value as the
        replacement gradient).
        """
        dist.all_reduce(grad, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
        grad = grad / self.summa_dim
        return grad

    def forward(self, x: Tensor) -> Tensor:
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return x
@LAYERS.register_module
class ViTTokenFuser2D(ParallelLayer):
    """
    Fuse cls token and pos embedding to the input.

    Both learnable tensors are partitioned along the embedding dimension
    across the SUMMA dimension.

    :param img_size: image size
    :type img_size: int
    :param patch_size: patch size
    :type patch_size: int
    :param embed_dim: dimension of embedding
    :type embed_dim: int
    :param drop_rate: dropout probability, defaults to 0.
    :type drop_rate: float, optional
    """

    def __init__(self,
                 img_size,
                 patch_size,
                 embed_dim,
                 drop_rate=0.
                 ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        assert_summa_initialization()
        self.summa_dim = get_summa_dim_from_env()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.embed_dim = embed_dim

        # Per-rank slices: embedding dimension divided by summa_dim.
        self.cls_token = nn.Parameter(torch.zeros(
            1, 1, self.embed_dim // self.summa_dim))
        self.pos_embed = nn.Parameter(torch.zeros(
            1, self.num_patches + 1, self.embed_dim // self.summa_dim))

        # move to cuda before broadcast
        self.to(get_current_device())

        # sync param in both forward and backward
        # NOTE(review): torch.cat allocates NEW storage, so `_param` does not
        # share memory with `cls_token` / `pos_embed`; broadcasting `_param`
        # and hooking its gradient may therefore not actually synchronize the
        # two parameters — verify against the intended sync semantics.
        _cls_token = self.cls_token.view(-1)
        _pos_embed = self.pos_embed.view(-1)
        self._param = torch.cat([_cls_token, _pos_embed], dim=0)
        self._broadcast_params(self._param)
        self._param.register_hook(self._sync_grad_hook)

        self.pos_drop = nn.Dropout(p=drop_rate)
        self._set_tensor_parallel_attribute()

    def _set_tensor_parallel_attribute(self):
        # Tag both parameters as tensor-parallel.
        set_tensor_parallel_attribute(self.cls_token)
        set_tensor_parallel_attribute(self.pos_embed)

    def _broadcast_params(self, param) -> None:
        " broadcast to all column ranks for data consistency "
        ranks_in_col = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)
        col_group = gpc.get_group(ParallelMode.PARALLEL_2D_COL)
        dist.broadcast(param, src=ranks_in_col[0],
                       group=col_group)

    def _sync_grad_hook(self, grad: Tensor) -> Tensor:
        """All-reduce the gradient over the column group and average it.

        Annotation corrected: the hook returns the rescaled gradient
        (PyTorch substitutes a hook's return value for the gradient).
        """
        dist.all_reduce(grad, group=gpc.get_group(
            ParallelMode.PARALLEL_2D_COL))
        grad = grad / self.summa_dim
        return grad

    def forward(self, x: Tensor) -> Tensor:
        # stole cls_tokens impl from Phil Wang, thanks
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        # Positional-dropout mask drawn from the tensor-parallel RNG state.
        with seed(ParallelMode.TENSOR):
            x = self.pos_drop(x + self.pos_embed)
        return x
@LAYERS.register_module
class ViTInputSplitter2D(ParallelLayer):
    """Split the input tensor for 2D parallel Vision Transformer."""

    def __init__(self):
        super().__init__()
        assert_summa_initialization()
        self.summa_dim = get_summa_dim_from_env()

    def forward(self, x: Tensor) -> Tensor:
        """Dispatch the split through the autograd-aware 2D splitting op."""
        num_samples = x.size(0)
        return _ViT_Split_Input_2D.apply(x,
                                         num_samples,
                                         self.summa_dim,
                                         ParallelMode.PARALLEL_2D_COL)
colossalai/nn/layer/parallel_2d/layers.py
0 → 100644
View file @
404ecbdc
import
math
import
torch
import
torch.distributed
as
dist
from
torch
import
Tensor
from
torch.nn
import
Parameter
,
init
as
init
from
colossalai.context
import
seed
,
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.registry
import
LAYERS
from
colossalai.utils
import
get_current_device
from
._operation
import
Matmul_AB_2D
,
Add_Bias_2D
,
_LayerNorm_2D
from
._utils
import
get_summa_dim_from_env
,
assert_summa_initialization
from
.._common_utils
import
divide
,
set_tensor_parallel_attribute
from
..base_layer
import
ParallelLayer
@LAYERS.register_module
class Linear2D(ParallelLayer):
    """ Linear layer for 2D parallelism.

    The weight of shape ``[in_features, out_features]`` is partitioned into
    a ``summa_dim x summa_dim`` grid: each rank holds a
    ``[in/q, out/q]`` tile and a ``[out/q]`` bias slice.

    :param in_features: size of each input sample
    :type in_features: int
    :param out_features: size of each output sample
    :type out_features: int
    :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
    :type bias: bool, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False
    :type skip_bias_add: bool, optional
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype=None,
                 skip_bias_add: bool = False
                 ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.skip_bias_add = skip_bias_add

        # parallel settings
        assert_summa_initialization()
        # Note the mapping: the ROW coordinate is the local rank within the
        # COLUMN process group, and vice versa.
        self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
        self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
        self.summa_dim = get_summa_dim_from_env()

        # partitioning dimension
        self.input_size_per_partition = divide(self.in_features,
                                               self.summa_dim)
        self.hidden_size_per_partition = divide(self.out_features,
                                                self.summa_dim)

        # create weight, shape: [k/q, h/q]
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
        self.weight = Parameter(torch.empty(self.input_size_per_partition,
                                            self.hidden_size_per_partition,
                                            **factory_kwargs))

        # create bias, shape: [h/q]
        if bias:
            self.bias = Parameter(torch.empty(self.hidden_size_per_partition,
                                              **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        # initialize parameters
        self.reset_parameters()
        self._set_tensor_parallel_attributes()

    def _set_tensor_parallel_attributes(self):
        # Tag parameters as tensor-parallel for the surrounding framework.
        set_tensor_parallel_attribute(self.weight)
        if self.bias is not None:
            set_tensor_parallel_attribute(self.bias)

    def reset_parameters(self) -> None:
        """Kaiming-uniform-style init, computed from the FULL fan-in
        (``in_features``), not the per-partition tile size."""
        # setting
        fan_in = self.in_features
        a = math.sqrt(5)
        nonlinearity = 'leaky_relu'

        # init weight
        std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
        bound = math.sqrt(3.0) * std
        with seed(ParallelMode.TENSOR):
            init.uniform_(self.weight, -bound, bound)

        # init bias
        if self.bias is not None:
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            with seed(ParallelMode.TENSOR):
                init.uniform_(self.bias, -bound, bound)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the 2D-parallel matmul (plus bias handling).

        Returns a single tensor, EXCEPT when ``bias`` exists and
        ``skip_bias_add`` is True, in which case a ``(output, bias)`` tuple
        is returned for downstream kernel fusion.
        """
        # input: [m/q, n/q, k/q]
        # output: [m/q, n/q, h/q]
        out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
        output = Matmul_AB_2D.apply(
            x,
            self.weight,
            self.summa_dim,
            out_shape,
            self.row_rank,
            self.col_rank,
            ParallelMode.PARALLEL_2D_ROW,
            ParallelMode.PARALLEL_2D_COL,
            self.data_parallel_rank,
            self.pipeline_parallel_rank,
            self.pipeline_parallel_size,
            self.tensor_parallel_size
        )

        if self.bias is not None:
            if self.skip_bias_add:
                # Broadcast the bias slice but leave the addition to the caller.
                bias = Add_Bias_2D.apply(
                    None,
                    self.bias,
                    self.hidden_size_per_partition,
                    self.row_rank,
                    self.col_rank,
                    ParallelMode.PARALLEL_2D_ROW,
                    ParallelMode.PARALLEL_2D_COL,
                    True,
                    self.data_parallel_rank,
                    self.pipeline_parallel_rank,
                    self.pipeline_parallel_size,
                    self.tensor_parallel_size
                )
                return output, bias
            else:
                output = Add_Bias_2D.apply(
                    output,
                    self.bias,
                    self.hidden_size_per_partition,
                    self.row_rank,
                    self.col_rank,
                    ParallelMode.PARALLEL_2D_ROW,
                    ParallelMode.PARALLEL_2D_COL,
                    False,
                    self.data_parallel_rank,
                    self.pipeline_parallel_rank,
                    self.pipeline_parallel_size,
                    self.tensor_parallel_size
                )
                return output
        else:
            return output
@LAYERS.register_module
class LayerNorm2D(ParallelLayer):
    r"""Layer Normalization for 2D parallelism.

    Mean/variance statistics are computed with all-reduces over the ROW
    process group; scale (``gamma``) and shift (``beta``) live only on
    row-rank 0 and are distributed via :class:`Add_Bias_2D`.

    :param normalized_shape: input shape from an expected input
        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
        If a single integer is used, it is treated as a singleton list, and this module will
        normalize over the last dimension which is expected to be of that specific size.
    :type normalized_shape: int
    :param eps: a value added to the denominator for numerical stability, defaults to 1e-05
    :type eps: float, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def __init__(self,
                 normalized_shape: int,
                 eps: float = 1e-05,
                 dtype=None
                 ):
        super().__init__()

        # layer norm config
        self.normalized_shape = normalized_shape
        self.variance_epsilon = eps

        # parallel setting
        assert_summa_initialization()
        # Note the mapping: the ROW coordinate is the local rank within the
        # COLUMN process group, and vice versa.
        self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
        self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
        self.summa_dim = get_summa_dim_from_env()

        # partitioning dimension
        self.partitioned_partition = divide(normalized_shape, self.summa_dim)

        # create parameters
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}

        if self.row_rank == 0:
            # Only row 0 owns the real affine parameters.
            self.gamma = Parameter(torch.ones(
                self.partitioned_partition,
                **factory_kwargs))
            self.beta = Parameter(torch.zeros(
                self.partitioned_partition,
                **factory_kwargs))
        else:
            # Scalar placeholders on other rows; the real values are
            # broadcast from row 0 by Add_Bias_2D in forward.
            self.gamma = Parameter(torch.tensor(
                1.0,
                requires_grad=True,
                **factory_kwargs))
            self.beta = Parameter(torch.tensor(
                1.0,
                requires_grad=True,
                **factory_kwargs))
        self._set_tensor_parallel_attributes()

    def _set_tensor_parallel_attributes(self):
        # Tag parameters as tensor-parallel for the surrounding framework.
        set_tensor_parallel_attribute(self.gamma)
        set_tensor_parallel_attribute(self.beta)

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` over the (row-sharded) last dimension, then apply
        the distributed affine transform."""
        # Statistics are constants w.r.t. autograd here; gradients flow
        # through _LayerNorm_2D instead.
        with torch.no_grad():
            E_x = torch.sum(x, dim=-1, keepdim=True)  # [b/q, s, 1]
            torch.distributed.all_reduce(
                E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
            E_x /= self.normalized_shape

            # Var_x in the block below is the sum of input^2
            Var_x = torch.sum(x * x, dim=-1, keepdim=True)  # [b/q, s, 1]
            torch.distributed.all_reduce(
                Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
            Var_x /= self.normalized_shape

            Var_x = Var_x - E_x * E_x  # variance of x [b/q, s, 1]
            # this time 1/sqrt(Var_x + epsilon)
            Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)

        output = _LayerNorm_2D.apply(x, E_x, Var_x, self.normalized_shape,
                                     ParallelMode.PARALLEL_2D_ROW,
                                     ParallelMode.PARALLEL_2D_COL)
        # Fetch beta / gamma slices (broadcast from row-rank 0).
        bias = Add_Bias_2D.apply(
            None, self.beta, self.partitioned_partition,
            self.row_rank, self.col_rank,
            ParallelMode.PARALLEL_2D_ROW,
            ParallelMode.PARALLEL_2D_COL,
            True,
            self.data_parallel_rank,
            self.pipeline_parallel_rank,
            self.pipeline_parallel_size,
            self.tensor_parallel_size
        )
        scale = Add_Bias_2D.apply(
            None, self.gamma, self.partitioned_partition,
            self.row_rank, self.col_rank,
            ParallelMode.PARALLEL_2D_ROW,
            ParallelMode.PARALLEL_2D_COL,
            True,
            self.data_parallel_rank,
            self.pipeline_parallel_rank,
            self.pipeline_parallel_size,
            self.tensor_parallel_size
        )
        # output = bias + scale * normalized_x
        output = torch.addcmul(bias, scale, output)
        return output
colossalai/nn/layer/parallel_2p5d/__init__.py
0 → 100644
View file @
404ecbdc
from
._operation
import
Matmul_AB_2p5D
,
Matmul_ABT_2p5D
,
Matmul_ATB_2p5D
,
Sum_2p5D
,
Add_Bias_2p5D
from
._transformer
import
TransformerMLP2p5D
,
TransformerSelfAttention2p5D
,
TransformerLayer2p5D
from
._vit
import
(
ViTMLP2p5D
,
ViTSelfAttention2p5D
,
ViTHead2p5D
,
ViTPatchEmbedding2p5D
,
ViTTokenFuser2p5D
,
ViTInputSplitter2p5D
)
from
.layers
import
Linear2p5D
,
LayerNorm2p5D
# Public API of the 2.5-D tensor-parallel layer package.
__all__ = [
    'Matmul_AB_2p5D', 'Matmul_ABT_2p5D', 'Matmul_ATB_2p5D', 'Sum_2p5D', 'Add_Bias_2p5D',
    'TransformerMLP2p5D', 'TransformerSelfAttention2p5D', 'TransformerLayer2p5D',
    'ViTMLP2p5D', 'ViTSelfAttention2p5D', 'ViTHead2p5D', 'ViTPatchEmbedding2p5D',
    'ViTTokenFuser2p5D', 'ViTInputSplitter2p5D',
    'Linear2p5D', 'LayerNorm2p5D'
]
colossalai/nn/layer/parallel_2p5d/_operation.py
0 → 100644
View file @
404ecbdc
from
typing
import
Any
,
Tuple
import
torch
import
torch.distributed
as
dist
from
torch
import
Tensor
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.utils
import
get_current_device
,
empty_cache
def get_parallel_group(parallel_mode: ParallelMode):
    """Return the process group registered for ``parallel_mode``."""
    return gpc.get_group(parallel_mode)


def get_global_rank():
    """Return the global rank of the current process."""
    return gpc.get_global_rank()


def get_parallel_rank(parallel_mode: ParallelMode):
    """Return this process's local rank within ``parallel_mode``'s group."""
    return gpc.get_local_rank(parallel_mode)
class Matmul_AB_2p5D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = AB` in 2.5-D tensor parallelism.

    Each rank holds one tile of A and B; forward runs a broadcast-
    multiply-accumulate loop over the tesseract dimension (SUMMA-style).
    """

    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                tesseract_dim: int,
                tesseract_dep: int,
                out_shape: Tuple[int, ...],
                row_rank: int,
                col_rank: int,
                dep_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                dep_parallel_mode: ParallelMode,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int) -> Tensor:
        # A: [b / dq, s, h / q] -> [(b * s) / dq, h / q]
        # B: [h / dq, s / q]
        # C: [b / dq, s, s / q] -> [(b * s) / dq, s / q]
        assert A.shape[-1] == B.shape[-2], \
            'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape)

        empty_cache()
        # `ctx` is None when invoked directly from a sibling backward();
        # no autograd bookkeeping is needed in that case.
        if ctx:
            ctx.save_for_backward(A, B)

        # Flatten leading dims so the loop works on plain 2-D matrices.
        A_shape = A.shape
        A = A.reshape((-1, A_shape[-1]))
        B_shape = B.shape
        B = B.reshape((-1, B_shape[-1]))
        C_shape = (A.shape[0], B.shape[-1])
        C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())

        # At step i, the owner of A's i-th column tile broadcasts along the
        # row group and the owner of B's i-th row tile along the column
        # group; every rank accumulates the partial product.
        for i in range(tesseract_dim):
            A_temp = A.clone()
            B_temp = B.clone()
            # Global source rank of the tile owner, offset by this process's
            # data/pipeline coordinates — assumes the rank layout produced by
            # the 2.5-D process group initializer (verify on layout changes).
            src_a = i + row_rank * tesseract_dim + dep_rank * (
                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size
            dist.broadcast(A_temp, src=src_a,
                           group=get_parallel_group(row_parallel_mode))
            src_b = col_rank + i * tesseract_dim + dep_rank * (
                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + pipeline_parallel_rank * tensor_parallel_size
            dist.broadcast(B_temp, src=src_b,
                           group=get_parallel_group(col_parallel_mode))
            # C = C + A_temp @ B_temp, accumulated in place via out=C.
            torch.addmm(C, A_temp, B_temp, out=C)

        out = C.reshape(out_shape)

        if ctx:
            # Stash everything backward() needs to rebuild the same topology.
            ctx.tesseract_dim = tesseract_dim
            ctx.tesseract_dep = tesseract_dep
            ctx.row_rank = row_rank
            ctx.col_rank = col_rank
            ctx.dep_rank = dep_rank
            ctx.row_parallel_mode = row_parallel_mode
            ctx.col_parallel_mode = col_parallel_mode
            ctx.dep_parallel_mode = dep_parallel_mode
            ctx.A_shape = A_shape
            ctx.B_shape = B_shape
            ctx.data_parallel_rank = data_parallel_rank
            ctx.pipeline_parallel_rank = pipeline_parallel_rank
            ctx.pipeline_parallel_size = pipeline_parallel_size
            ctx.tensor_parallel_size = tensor_parallel_size
        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        """dA = dC B^T and dB = A^T dC, computed via the sibling kernels."""
        A, B = ctx.saved_tensors
        A_grad = Matmul_ABT_2p5D.forward(
            None,
            output_grad, B,
            ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
            ctx.row_rank, ctx.col_rank, ctx.dep_rank,
            ctx.row_parallel_mode, ctx.col_parallel_mode,
            ctx.dep_parallel_mode,
            ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
        B_grad = Matmul_ATB_2p5D.forward(
            None,
            A, output_grad,
            ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
            ctx.row_rank, ctx.col_rank, ctx.dep_rank,
            ctx.row_parallel_mode, ctx.col_parallel_mode,
            ctx.dep_parallel_mode,
            ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
        # One gradient slot per forward() argument; non-tensor args get None.
        return A_grad, B_grad, None, None, None, None, None, None, None, \
            None, None, None, None, None, None
class Matmul_ABT_2p5D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = AB^T` in 2.5-D tensor parallelism.

    Forward broadcasts B tiles along the column group, forms partial
    products, and reduces each partial result to the rank that owns the
    corresponding output tile.
    """

    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                tesseract_dim: int,
                tesseract_dep: int,
                out_shape: Tuple[int, ...],
                row_rank: int,
                col_rank: int,
                dep_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                dep_parallel_mode: ParallelMode,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int) -> Tensor:
        assert A.shape[-1] == B.shape[-1], \
            'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape)

        empty_cache()
        # `ctx` is None when invoked directly from a sibling backward().
        if ctx:
            ctx.save_for_backward(A, B)

        # Flatten leading dims so the loop works on plain 2-D matrices.
        A_shape = A.shape
        A = A.reshape((-1, A_shape[-1]))
        B_shape = B.shape
        B = B.reshape((-1, B_shape[-1]))
        C_shape = (A.shape[0], B.shape[0])
        C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())

        for i in range(tesseract_dim):
            B_temp = B.clone()
            # Source rank of B's i-th tile in the column group; rank formula
            # follows the 2.5-D process group initializer's layout.
            src_b = col_rank + i * tesseract_dim + dep_rank * (
                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.broadcast(B_temp, src=src_b,
                           group=gpc.get_group(col_parallel_mode))
            C_temp = torch.matmul(A, B_temp.transpose(0, 1))
            # Reduce the partial product to the output-tile owner in the row
            # group; this rank keeps the result only on its own turn.
            src_c = i + row_rank * tesseract_dim + dep_rank * (
                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(C_temp, dst=src_c,
                        group=gpc.get_group(row_parallel_mode))
            if i == col_rank:
                C = C_temp.clone()

        out = C.reshape(out_shape)

        if ctx:
            # Stash everything backward() needs to rebuild the same topology.
            ctx.tesseract_dim = tesseract_dim
            ctx.tesseract_dep = tesseract_dep
            ctx.row_rank = row_rank
            ctx.col_rank = col_rank
            ctx.dep_rank = dep_rank
            ctx.row_parallel_mode = row_parallel_mode
            ctx.col_parallel_mode = col_parallel_mode
            ctx.dep_parallel_mode = dep_parallel_mode
            ctx.A_shape = A_shape
            ctx.B_shape = B_shape
            ctx.data_parallel_rank = data_parallel_rank
            ctx.pipeline_parallel_rank = pipeline_parallel_rank
            ctx.pipeline_parallel_size = pipeline_parallel_size
            ctx.tensor_parallel_size = tensor_parallel_size
        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        """dA = dC B and dB = dC^T A, computed via the sibling kernels."""
        A, B = ctx.saved_tensors
        A_grad = Matmul_AB_2p5D.forward(
            None,
            output_grad, B,
            ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
            ctx.row_rank, ctx.col_rank, ctx.dep_rank,
            ctx.row_parallel_mode, ctx.col_parallel_mode,
            ctx.dep_parallel_mode,
            ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
        B_grad = Matmul_ATB_2p5D.forward(
            None,
            output_grad, A,
            ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
            ctx.row_rank, ctx.col_rank, ctx.dep_rank,
            ctx.row_parallel_mode, ctx.col_parallel_mode,
            ctx.dep_parallel_mode,
            ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
        # One gradient slot per forward() argument; non-tensor args get None.
        return A_grad, B_grad, None, None, None, None, None, None, None, \
            None, None, None, None, None, None
class Matmul_ATB_2p5D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = A^TB` in 2.5-D tensor parallelism.

    Forward broadcasts A tiles along the row group, forms partial products,
    and reduces each partial result to the rank that owns the corresponding
    output tile.
    """

    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                tesseract_dim: int,
                tesseract_dep: int,
                out_shape: Tuple[int, ...],
                row_rank: int,
                col_rank: int,
                dep_rank: int,
                row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                dep_parallel_mode: ParallelMode,
                data_parallel_rank: int,
                pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int) -> Tensor:
        assert A.shape[-2] == B.shape[-2], \
            'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape)

        empty_cache()
        # `ctx` is None when invoked directly from a sibling backward().
        if ctx:
            ctx.save_for_backward(A, B)

        # Flatten leading dims so the loop works on plain 2-D matrices.
        A_shape = A.shape
        A = A.reshape((-1, A_shape[-1]))
        B_shape = B.shape
        B = B.reshape((-1, B_shape[-1]))
        C_shape = (A.shape[-1], B.shape[-1])
        C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())

        for i in range(tesseract_dim):
            A_temp = A.clone()
            # Source rank of A's i-th tile in the row group; rank formula
            # follows the 2.5-D process group initializer's layout.
            src_a = i + row_rank * tesseract_dim + dep_rank * (
                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.broadcast(A_temp, src=src_a,
                           group=get_parallel_group(row_parallel_mode))
            C_temp = torch.matmul(A_temp.transpose(0, 1), B)
            # Reduce the partial product to the output-tile owner in the
            # column group; this rank keeps the result only on its own turn.
            src_c = col_rank + i * tesseract_dim + dep_rank * (
                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(C_temp, dst=src_c,
                        group=get_parallel_group(col_parallel_mode))
            if i == row_rank:
                C = C_temp.clone()

        out = C.reshape(out_shape)

        if ctx:
            # Stash everything backward() needs to rebuild the same topology.
            ctx.tesseract_dim = tesseract_dim
            ctx.tesseract_dep = tesseract_dep
            ctx.row_rank = row_rank
            ctx.col_rank = col_rank
            ctx.dep_rank = dep_rank
            ctx.row_parallel_mode = row_parallel_mode
            ctx.col_parallel_mode = col_parallel_mode
            ctx.dep_parallel_mode = dep_parallel_mode
            ctx.A_shape = A_shape
            ctx.B_shape = B_shape
            ctx.data_parallel_rank = data_parallel_rank
            ctx.pipeline_parallel_rank = pipeline_parallel_rank
            ctx.pipeline_parallel_size = pipeline_parallel_size
            ctx.tensor_parallel_size = tensor_parallel_size
        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        """dA = B dC^T and dB = A dC, computed via the sibling kernels."""
        A, B = ctx.saved_tensors
        A_grad = Matmul_ABT_2p5D.forward(
            None,
            B, output_grad,
            ctx.tesseract_dim, ctx.tesseract_dep, ctx.A_shape,
            ctx.row_rank, ctx.col_rank, ctx.dep_rank,
            ctx.row_parallel_mode, ctx.col_parallel_mode,
            ctx.dep_parallel_mode,
            ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
        B_grad = Matmul_AB_2p5D.forward(
            None,
            A, output_grad,
            ctx.tesseract_dim, ctx.tesseract_dep, ctx.B_shape,
            ctx.row_rank, ctx.col_rank, ctx.dep_rank,
            ctx.row_parallel_mode, ctx.col_parallel_mode,
            ctx.dep_parallel_mode,
            ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
            ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
        # One gradient slot per forward() argument; non-tensor args get None.
        return A_grad, B_grad, None, None, None, None, None, None, None, \
            None, None, None, None, None, None
class Add_Bias_2p5D(torch.autograd.Function):
    """Matrix add bias: :math:`C = A + b`

    The bias slice is only authoritative on the ``row_rank == 0`` member of
    each column process group; it is broadcast to the other ranks before the
    add. When ``skip_bias_add`` is True, the broadcast bias itself is
    returned instead of ``input + bias`` (used for fused bias handling).
    """

    @staticmethod
    def forward(ctx: Any, input: Tensor, bias: Tensor,
                output_size_per_partition: int, tesseract_dim: int,
                tesseract_dep: int, row_rank: int, col_rank: int,
                dep_rank: int, row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                dep_parallel_mode: ParallelMode, skip_bias_add: bool,
                data_parallel_rank: int, pipeline_parallel_rank: int,
                pipeline_parallel_size: int,
                tensor_parallel_size: int) -> Tensor:
        # Only row rank 0 holds the real bias; other ranks allocate a zero
        # buffer of the same partition size to receive the broadcast.
        if row_rank == 0:
            bias_temp = bias.clone()
        else:
            bias_temp = torch.zeros(
                output_size_per_partition,
                dtype=bias.dtype,
                device=get_current_device())
        # Global rank of the row-0 member of this column group, offset by the
        # data-parallel and pipeline-parallel coordinates of this process.
        src_rank = col_rank + dep_rank * (
            tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
            pipeline_parallel_rank * tensor_parallel_size
        dist.broadcast(bias_temp,
                       src=src_rank,
                       group=get_parallel_group(col_parallel_mode))
        # Stash parallel topology for backward.
        ctx.row_rank = row_rank
        ctx.col_rank = col_rank
        ctx.dep_rank = dep_rank
        ctx.tesseract_dim = tesseract_dim
        ctx.tesseract_dep = tesseract_dep
        ctx.row_parallel_mode = row_parallel_mode
        ctx.col_parallel_mode = col_parallel_mode
        ctx.dep_parallel_mode = dep_parallel_mode
        # NOTE: despite the name, ctx.bias stores the skip_bias_add flag.
        ctx.bias = skip_bias_add
        ctx.data_parallel_rank = data_parallel_rank
        ctx.pipeline_parallel_rank = pipeline_parallel_rank
        ctx.pipeline_parallel_size = pipeline_parallel_size
        ctx.tensor_parallel_size = tensor_parallel_size
        if skip_bias_add:
            return bias_temp
        else:
            output = input + bias_temp
            return output

    @staticmethod
    def backward(ctx, output_grad):
        row_rank = ctx.row_rank
        col_rank = ctx.col_rank
        dep_rank = ctx.dep_rank
        tesseract_dim = ctx.tesseract_dim
        tesseract_dep = ctx.tesseract_dep
        row_parallel_mode = ctx.row_parallel_mode
        col_parallel_mode = ctx.col_parallel_mode
        dep_parallel_mode = ctx.dep_parallel_mode
        data_parallel_rank = ctx.data_parallel_rank
        pipeline_parallel_rank = ctx.pipeline_parallel_rank
        pipeline_parallel_size = ctx.pipeline_parallel_size
        tensor_parallel_size = ctx.tensor_parallel_size
        if ctx.bias:
            # forward returned the broadcast bias itself: the bias gradient
            # is the reduction of output_grad onto the row-0 source rank.
            dst_rank = col_rank + dep_rank * (
                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            # NOTE(review): dist.reduce mutates output_grad in place on every
            # rank; only dst_rank ends up with the full sum.
            dist.reduce(output_grad,
                        dst=dst_rank,
                        group=get_parallel_group(col_parallel_mode))
            if row_rank == 0:
                return None, output_grad, None, None, None, None, None, None, \
                    None, None, None, None, None, None, None, None
            else:
                # Non-source ranks contribute zero gradient for the bias.
                grad_tmp = torch.zeros_like(output_grad)
                return None, grad_tmp, None, None, None, None, None, None, \
                    None, None, None, None, None, None, None, None
        else:
            # forward returned input + bias: input grad passes through,
            # bias grad is output_grad summed over all non-feature dims
            # and reduced onto the row-0 source rank.
            reduce_dim = tuple(range(output_grad.ndim - 1))
            reduce = torch.sum(output_grad, dim=reduce_dim)
            dst_rank = col_rank + dep_rank * (
                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(reduce,
                        dst=dst_rank,
                        group=get_parallel_group(col_parallel_mode))
            # NOTE(review): these branches return 17 gradient slots while
            # forward takes 16 inputs; PyTorch strips trailing Nones, so this
            # is tolerated — confirm and tidy in a follow-up.
            if row_rank == 0:
                return output_grad, reduce, None, None, None, None, None, \
                    None, None, None, None, None, None, None, None, None, None
            else:
                reduce_tmp = torch.zeros_like(reduce)
                return output_grad, reduce_tmp, None, None, None, None, None, \
                    None, None, None, None, None, None, None, None, None, None
class _LayerNorm_2p5D(torch.autograd.Function):
    """Row-parallel layer-norm core: normalizes a pre-centered input and
    computes the matching gradient with all-reduced row statistics."""

    @staticmethod
    def forward(ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor,
                hidden_size: int, row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode,
                dep_parallel_mode: ParallelMode) -> Tensor:
        # Center the input; caller supplies the mean E_x and the inverse
        # standard deviation Var_x.
        input = input - E_x
        # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
        ctx.hidden_size = hidden_size
        output = input * Var_x
        # Save the normalized output (not the raw input) plus Var_x — both
        # are reused directly in the backward formula below.
        ctx.save_for_backward(output, Var_x)
        ctx.row_parallel_mode = row_parallel_mode
        ctx.col_parallel_mode = col_parallel_mode
        ctx.dep_parallel_mode = dep_parallel_mode
        return output

    @staticmethod
    def backward(ctx, output_grad):
        row_parallel_mode = ctx.row_parallel_mode
        col_parallel_mode = ctx.col_parallel_mode
        dep_parallel_mode = ctx.dep_parallel_mode
        x, Var_x = ctx.saved_tensors
        # in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x
        with torch.no_grad():
            # Mean of the incoming gradient over the full (row-sharded)
            # hidden dimension: local partial sum, then all-reduce across
            # the row group, then divide by the global hidden size.
            output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True)
            torch.distributed.all_reduce(
                output_grad_sum, group=get_parallel_group(row_parallel_mode))
            output_grad_sum /= ctx.hidden_size
            # Same treatment for the grad-times-normalized-output term.
            output_grad_mul_x_sum = torch.sum(
                output_grad * x, dim=-1, keepdim=True)
            torch.distributed.all_reduce(
                output_grad_mul_x_sum,
                group=get_parallel_group(row_parallel_mode))
            output_grad_mul_x_sum /= ctx.hidden_size
            # Standard layer-norm input gradient:
            # (dy - mean(dy) - x_hat * mean(dy * x_hat)) / std
            input_grad = output_grad.clone()
            input_grad -= x * output_grad_mul_x_sum
            input_grad -= output_grad_sum
            input_grad *= Var_x
        return input_grad, None, None, None, None, None, None
class Sum_2p5D(torch.autograd.Function):
    """Compute the sum of input tensors across the 2.5D row parallel group.

    Forward reduces locally with ``torch.sum`` along ``dim`` and then
    all-reduces the result over ``row_parallel_mode``, so every rank in the
    row group holds the same summed tensor.
    """

    @staticmethod
    def forward(ctx, inputs, dim, tesseract_dim, row_parallel_mode,
                keepdim=False):
        # input: [b/q, s, h/q]
        empty_cache()
        ctx.save_for_backward(inputs)
        # Needed in backward to restore the reduced dimension.
        ctx.dim = dim
        ctx.keepdim = keepdim
        # sum: [b/q, s]
        out = torch.sum(inputs, dim=dim, keepdim=keepdim)
        torch.distributed.all_reduce(out, group=gpc.get_group(row_parallel_mode))
        return out

    @staticmethod
    def backward(ctx, output_grad):
        with torch.no_grad():
            # FIX: ctx.saved_tensors is a tuple — unpack the single saved
            # tensor (the previous code called .shape on the tuple itself,
            # which raised AttributeError).
            inputs, = ctx.saved_tensors
            # The gradient of a sum is the incoming gradient broadcast back
            # over the reduced dimension. (FIX: was torch.ones, which both
            # ignored output_grad and was allocated on the CPU regardless of
            # the input's device.)
            grad = output_grad
            if not ctx.keepdim:
                grad = grad.unsqueeze(ctx.dim)
            input_grad = grad.expand_as(inputs).contiguous()
        # One gradient slot per forward input (inputs, dim, tesseract_dim,
        # row_parallel_mode, keepdim); only `inputs` receives a gradient.
        return input_grad, None, None, None, None
class _ViT_Split_2p5D(torch.autograd.Function):
    """Split the batch dimension across the XZ process group: each rank keeps
    one contiguous chunk of the batch in forward and the full-batch gradient
    is reassembled with all_gather in backward."""

    @staticmethod
    def forward(ctx, inputs, batch_size, tesseract_dim, tesseract_dep,
                xz_parallel_mode):
        # inputs: [b, s, h/q]
        # output: [b/dq, s, h/q]
        empty_cache()
        ctx.batch_size = batch_size
        ctx.tesseract_dim = tesseract_dim
        ctx.tesseract_dep = tesseract_dep
        ctx.xz_parallel_mode = xz_parallel_mode
        xz_rank = gpc.get_local_rank(xz_parallel_mode)
        # Keep only this rank's slice of the batch; the XZ group has
        # tesseract_dep * tesseract_dim members.
        output = torch.chunk(inputs, tesseract_dep * tesseract_dim,
                             dim=0)[xz_rank]
        # clone() detaches the slice from the full input's storage.
        output = output.clone()
        return output

    @staticmethod
    def backward(ctx, output_grad):
        # output_grad: [b/dq, s, h/q]
        # grads: [b, s, h/q]
        # *
        grads_shape = (ctx.batch_size,) + output_grad.shape[1:]
        grads = torch.empty(grads_shape,
                            dtype=output_grad.dtype,
                            device=get_current_device())
        # all_gather writes each rank's gradient chunk directly into the
        # corresponding view of the preallocated full-batch buffer.
        dist.all_gather(list(grads.chunk(
            ctx.tesseract_dim * ctx.tesseract_dep, dim=0)),
            output_grad.contiguous(),
            group=get_parallel_group(ctx.xz_parallel_mode))
        return grads, None, None, None, None
colossalai/nn/layer/parallel_2p5d/_transformer.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
math
import
torch
from
torch
import
nn
as
nn
,
Tensor
from
colossalai.nn.layer._common_utils
import
divide
from
colossalai.registry
import
LAYERS
from
._utils
import
assert_tesseract_initialization
,
\
get_tesseract_dim_dep_from_env
from
.layers
import
Linear2p5D
,
LayerNorm2p5D
from
.._common_utils
import
ACT2FN
@LAYERS.register_module
class TransformerMLP2p5D(nn.Module):
    """Position-wise feed-forward block for a 2.5D-parallel Transformer.

    Projects the input from ``in_features`` to ``mlp_ratio * in_features``,
    applies a non-linearity, projects back to ``in_features``, then applies
    dropout and a residual layer norm.

    :param in_features: the size of input tensor
    :type in_features: int
    :param mlp_ratio: hidden size of MLP divided by embedding dim
    :type mlp_ratio: int
    :param act_func: activation function, defaults to 'gelu'
    :type act_func: str, optional
    :param dropout_prob: dropout probability, defaults to 0.
    :type dropout_prob: float, optional
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def __init__(self,
                 in_features: int,
                 mlp_ratio: int,
                 act_func: str = 'gelu',
                 dropout_prob: float = 0.,
                 dtype=None,
                 ):
        super().__init__()
        assert_tesseract_initialization()
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
        self.in_features = in_features
        hidden_features = mlp_ratio * in_features
        # Expansion projection: h -> mlp_ratio * h.
        self.dense_1 = Linear2p5D(in_features, hidden_features, dtype=dtype)
        assert act_func in ACT2FN.keys(), \
            f'Invalid value for argument act_func, ' \
            f'activation function can only be {list(ACT2FN.keys())}'
        self.activation_func = ACT2FN[act_func]
        # Contraction projection: mlp_ratio * h -> h.
        self.dense_2 = Linear2p5D(hidden_features, in_features, dtype=dtype)
        self.dropout = nn.Dropout(dropout_prob)
        self.layernorm = LayerNorm2p5D(in_features, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        # Expand, activate, contract, regularize, then add the residual.
        hidden = self.activation_func(self.dense_1(x))
        out = self.dropout(self.dense_2(hidden))
        return self.layernorm(x + out)
@LAYERS.register_module
class TransformerSelfAttention2p5D(nn.Module):
    """Self attention layer for 2.5D parallel Transformer

    Includes the output projection, dropout, and a residual layer norm.

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_attention_heads: number of attention heads
    :type num_attention_heads: int
    :param attention_dropout_prob: dropout probability for attention layer
    :type attention_dropout_prob: float
    :param hidden_dropout_prob: dropout probability for hidden layer
    :type hidden_dropout_prob: float
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 attention_dropout_prob,
                 hidden_dropout_prob,
                 dtype=None,
                 ):
        super().__init__()
        assert_tesseract_initialization()
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
        self.hidden_size = hidden_size
        # Heads are sharded across the tesseract dimension; each rank owns
        # num_attention_heads / tesseract_dim heads.
        self.num_attention_heads = divide(num_attention_heads,
                                          self.tesseract_dim)  # *
        # Per-head size is computed from the *global* head count.
        self.attention_head_size = divide(hidden_size, num_attention_heads)
        self.all_head_size = self.num_attention_heads * \
            self.attention_head_size
        # Fused QKV projection: h -> 3h (sharded by Linear2p5D).
        self.query_key_value = Linear2p5D(
            hidden_size,
            3 * hidden_size,
            dtype=dtype,
        )
        self.attention_dropout = nn.Dropout(attention_dropout_prob)
        self.dense = Linear2p5D(
            hidden_size,
            hidden_size,
            dtype=dtype,
        )
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.layernorm = LayerNorm2p5D(
            hidden_size,
            dtype=dtype)

    def forward(self, hidden_states: Tensor,
                attention_mask: Tensor) -> Tensor:
        # [..., h] -> [..., 3h] fused projection.
        query_key_value = self.query_key_value(hidden_states)
        # Split the last dim into (heads, 3 * head_size)...
        new_qkv_shape = query_key_value.shape[:-1] + \
            (self.num_attention_heads, 3 * self.attention_head_size)
        query_key_value = query_key_value.view(new_qkv_shape)
        # ...and move heads ahead of the sequence dim: (b, heads, s, 3*d).
        query_key_value = query_key_value.permute((0, 2, 1, 3))
        query_layer, key_layer, value_layer = torch.chunk(query_key_value,
                                                          3,
                                                          dim=-1)
        # Scaled dot-product attention with additive mask.
        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))
        attention_scores = attention_scores / \
            math.sqrt(self.attention_head_size)
        attention_scores = attention_scores + attention_mask
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.attention_dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)
        # Back to (b, s, heads, d), then merge heads into all_head_size.
        context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
        new_context_layer_shape = context_layer.size()[
            :-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        # Output projection + dropout + residual layer norm.
        output = self.dense(context_layer)
        output = self.dropout(output)
        attention_output = self.layernorm(hidden_states + output)
        return attention_output
@LAYERS.register_module
class TransformerLayer2p5D(nn.Module):
    """Transformer layer which contains a self-attention layer and a MLP layer

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_attention_heads: number of attention heads
    :type num_attention_heads: int
    :param act_func: activation function, defaults to 'gelu'
    :type act_func: str, optional
    :param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4
    :type mlp_ratio: int, optional
    :param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
    :type attention_dropout_prob: float, optional
    :param hidden_dropout_prob: dropout probability for hidden layer, defaults to 0.
    :type hidden_dropout_prob: float, optional
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 act_func='gelu',
                 mlp_ratio=4,
                 attention_dropout_prob: float = 0.,
                 hidden_dropout_prob: float = 0.,
                 dtype=None,
                 ):
        super().__init__()
        # Self-attention sub-block (applies its own residual + layer norm).
        self.attention = TransformerSelfAttention2p5D(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            attention_dropout_prob=attention_dropout_prob,
            hidden_dropout_prob=hidden_dropout_prob,
            dtype=dtype,
        )
        # Feed-forward sub-block (applies its own residual + layer norm).
        self.mlp = TransformerMLP2p5D(
            in_features=hidden_size,
            dropout_prob=hidden_dropout_prob,
            act_func=act_func,
            mlp_ratio=mlp_ratio,
            dtype=dtype,
        )

    def forward(self, hidden_states: Tensor,
                attention_mask: Tensor) -> Tensor:
        # attention -> MLP; each sub-block handles its own residual path.
        return self.mlp(self.attention(hidden_states, attention_mask))
colossalai/nn/layer/parallel_2p5d/_utils.py
0 → 100644
View file @
404ecbdc
import
os
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
def get_tesseract_dim_dep_from_env():
    """Read the 2.5D tesseract dimension and depth from the environment.

    :return: ``(tesseract_dim, tesseract_dep)`` parsed from the
        ``TESSERACT_DIM`` and ``TESSERACT_DEP`` environment variables
    :rtype: tuple
    :raises EnvironmentError: if either variable is missing from the
        environment
    """
    try:
        tesseract_dim = int(os.environ['TESSERACT_DIM'])
        tesseract_dep = int(os.environ['TESSERACT_DEP'])
    except KeyError as e:
        # Chain the original KeyError so the missing variable is visible
        # in the traceback (the original code dropped the cause).
        raise EnvironmentError(
            'TESSERACT_DIM or TESSERACT_DEP is not found in the current environment, '
            'please make sure that you have used the correct process group initializer'
        ) from e
    assert tesseract_dim > 0, 'TESSERACT_DIM must be larger than zero'
    assert tesseract_dep > 0, 'TESSERACT_DEP must be larger than zero'
    return tesseract_dim, tesseract_dep
def assert_tesseract_initialization():
    """Assert that every 2.5D tesseract process group has been initialized."""
    required_modes = (
        ParallelMode.PARALLEL_2P5D_COL,
        ParallelMode.PARALLEL_2P5D_ROW,
        ParallelMode.PARALLEL_2P5D_DEP,
        ParallelMode.PARALLEL_2P5D_XZ,
    )
    assert all(gpc.is_initialized(mode) for mode in required_modes), \
        'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ must be initialized by the process group initializer'
colossalai/nn/layer/parallel_2p5d/_vit.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
math
import
torch
from
torch
import
nn
as
nn
,
Tensor
,
distributed
as
dist
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.nn.layer.vanilla_vision_transformer.layers
import
to_2tuple
from
colossalai.registry
import
LAYERS
from
colossalai.utils
import
get_current_device
from
._operation
import
_ViT_Split_2p5D
from
._utils
import
assert_tesseract_initialization
,
\
get_tesseract_dim_dep_from_env
from
.layers
import
Linear2p5D
from
.._common_utils
import
ACT2FN
,
divide
,
CheckpointModule
from
.._common_utils
import
set_tensor_parallel_attribute
@LAYERS.register_module
class ViTMLP2p5D(CheckpointModule):
    """MLP layer for 2.5D parallel Vision Transformer

    Expands the features by ``mlp_ratio``, applies activation and dropout,
    then projects back down and applies dropout again.

    :param in_features: size of each input sample
    :type in_features: int
    :param mlp_ratio: hidden size of MLP divided by embedding dim
    :type mlp_ratio: int
    :param act_func: activation function, defaults to 'gelu'
    :type act_func: str, optional
    :param dropout_prob: dropout probability, defaults to 0.
    :type dropout_prob: float, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param checkpoint: If set to `True`, activation checkpoint is used, defaults to `False`
    :type checkpoint: bool, optional
    """

    def __init__(self,
                 in_features: int,
                 mlp_ratio: int,
                 act_func: str = 'gelu',
                 dropout_prob: float = 0.,
                 dtype=None,
                 checkpoint: bool = False):
        super().__init__(checkpoint=checkpoint)
        assert_tesseract_initialization()
        self.in_features = in_features
        self.mlp_ratio = mlp_ratio
        hidden_dim = mlp_ratio * in_features
        # Expansion projection: h -> mlp_ratio * h.
        self.dense_1 = Linear2p5D(in_features, hidden_dim, dtype=dtype)
        self.act = ACT2FN[act_func]
        # Contraction projection: mlp_ratio * h -> h.
        self.dense_2 = Linear2p5D(hidden_dim, in_features, dtype=dtype)
        self.dropout = nn.Dropout(dropout_prob)

    def _forward(self, hidden_states: Tensor) -> Tensor:
        # dense -> activation -> dropout -> dense -> dropout
        expanded = self.dropout(self.act(self.dense_1(hidden_states)))
        return self.dropout(self.dense_2(expanded))
@LAYERS.register_module
class ViTSelfAttention2p5D(CheckpointModule):
    """Self-attention layer for 2.5D parallel Vision Transformer

    Unlike the Transformer variant, this block takes no attention mask and
    applies no residual layer norm.

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_attention_heads: number of attention heads
    :type num_attention_heads: int
    :param attention_dropout_prob: dropout probability for attention layers
    :type attention_dropout_prob: float
    :param hidden_dropout_prob: dropout probability for hidden layers
    :type hidden_dropout_prob: float
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param checkpoint: If set to `True`, activation checkpoint is used, defaults to `False`
    :type checkpoint: bool, optional
    """

    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 attention_dropout_prob,
                 hidden_dropout_prob,
                 dtype=None,
                 checkpoint: bool = False):
        super().__init__(checkpoint=checkpoint)
        assert_tesseract_initialization()
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
        self.hidden_size = hidden_size
        # Heads are sharded across the tesseract dimension; per-head size
        # uses the *global* head count.
        self.num_attention_heads = divide(num_attention_heads,
                                          self.tesseract_dim)  # *
        self.attention_head_size = divide(hidden_size, num_attention_heads)
        self.all_head_size = self.num_attention_heads * \
            self.attention_head_size
        # Fused QKV projection: h -> 3h (sharded by Linear2p5D).
        self.query_key_value = Linear2p5D(
            hidden_size,
            3 * hidden_size,
            dtype=dtype,
        )
        self.attention_dropout = nn.Dropout(attention_dropout_prob)
        self.dense = Linear2p5D(
            hidden_size,
            hidden_size,
            dtype=dtype,
        )
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def _forward(self, hidden_states: Tensor) -> Tensor:
        # [..., h] -> [..., 3h] fused projection.
        query_key_value = self.query_key_value(hidden_states)
        # Split last dim into (heads, 3 * head_size), then move heads ahead
        # of the sequence dim: (b, heads, s, 3*d).
        new_qkv_shape = query_key_value.shape[:-1] + \
            (self.num_attention_heads, 3 * self.attention_head_size)
        query_key_value = query_key_value.view(new_qkv_shape)
        query_key_value = query_key_value.permute((0, 2, 1, 3))
        query_layer, key_layer, value_layer = torch.chunk(query_key_value,
                                                          3,
                                                          dim=-1)
        # Scaled dot-product attention (no mask in the ViT variant).
        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))
        attention_scores = attention_scores / \
            math.sqrt(self.attention_head_size)
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.attention_dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)
        # Back to (b, s, heads, d), then merge heads into all_head_size.
        context_layer = context_layer.transpose(1, 2)
        new_context_layer_shape = context_layer.size()[
            :-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)
        # Output projection + dropout (no residual/norm here).
        output = self.dense(context_layer)
        output = self.dropout(output)
        return output
@LAYERS.register_module
class ViTHead2p5D(nn.Module):
    """Output layer for 2.5D parallel Vision Transformer

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_classes: number of classes
    :type num_classes: int
    :param dtype: dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def __init__(self,
                 hidden_size,
                 num_classes,
                 dtype=None,
                 ):
        super().__init__()
        assert_tesseract_initialization()
        self.linear = Linear2p5D(hidden_size, num_classes, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        # Classify from the token at sequence position 0 (the class token).
        cls_embedding = x[:, 0]
        return self.linear(cls_embedding)
@LAYERS.register_module
class ViTPatchEmbedding2p5D(nn.Module):
    """ 2.5D Image to Patch Embedding

    The convolution output channels are sharded by the tesseract dimension;
    the conv parameters are broadcast across the XZ process group at
    construction and their gradients are all-reduced (and averaged) over the
    same group via backward hooks.

    :param img_size: iamge size
    :type img_size: int
    :param patch_size: patch size
    :type patch_size: int
    :param embed_dim: dimension of embedding
    :type embed_dim: int
    :param in_chans: number of channels of input image, defaults to 3
    :type in_chans: int, optional
    :param flatten: whether to flatten output tensor, defaults to True
    :type flatten: bool, optional
    """

    def __init__(self,
                 img_size,
                 patch_size,
                 embed_dim,
                 in_chans=3,
                 flatten=True):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        assert_tesseract_initialization()
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        # Embedding channels are sharded across the tesseract dimension.
        self.embed_dim = embed_dim // self.tesseract_dim  # *

        self.proj = nn.Conv2d(in_chans,
                              self.embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size,
                              )

        # move self to cuda before sync
        self.to(get_current_device())

        # sync
        self._broadcast_conv_params()
        self.proj.weight.register_hook(self._sync_grad_during_backward)
        self.proj.bias.register_hook(self._sync_grad_during_backward)

    def _broadcast_conv_params(self) -> None:
        """Copy rank 0's conv parameters to every rank in the XZ group."""
        xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
        dist.broadcast(self.proj.weight, src=xz_rank[0],
                       group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
        dist.broadcast(self.proj.bias, src=xz_rank[0],
                       group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))

    def _sync_grad_during_backward(self, grad: Tensor) -> Tensor:
        # FIX: return annotation was `-> None`, but this is a Tensor hook:
        # the returned tensor replaces the gradient.
        dist.all_reduce(grad, group=gpc.get_group(
            ParallelMode.PARALLEL_2P5D_XZ))
        # Average the summed gradient over the XZ group size.
        grad = grad / self.tesseract_dim / self.tesseract_dep  # *
        return grad

    def forward(self, x: Tensor) -> Tensor:
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return x
@LAYERS.register_module
class ViTTokenFuser2p5D(nn.Module):
    """
    Fuse cls token and pos embedding to the input

    The class token and position embedding are sharded by the tesseract
    dimension, broadcast across the XZ process group at construction, and
    their gradients are all-reduced (and averaged) over the same group via
    backward hooks.

    :param img_size: image size
    :type img_size: int
    :param patch_size: patch size
    :type patch_size: int
    :param embed_dim: dimension of embedding
    :type embed_dim: int
    :param drop_rate: dropout probability, defaults to 0.
    :type drop_rate: float, optional
    """

    def __init__(self,
                 img_size,
                 patch_size,
                 embed_dim,
                 drop_rate=0.
                 ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        assert_tesseract_initialization()
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.embed_dim = embed_dim

        # Both parameters hold only this rank's 1/tesseract_dim slice of the
        # embedding channels.
        self.cls_token = nn.Parameter(torch.zeros(
            1, 1, self.embed_dim // self.tesseract_dim))  # *
        self.pos_embed = nn.Parameter(torch.zeros(
            1, self.num_patches + 1, self.embed_dim // self.tesseract_dim))  # *

        # move to cuda before broadcast
        self.to(get_current_device())
        self._broadcast_params()
        self.cls_token.register_hook(self._sync_grad_hook)
        self.pos_embed.register_hook(self._sync_grad_hook)
        self.pos_drop = nn.Dropout(p=drop_rate)
        self._set_tensor_parallel_attribute()

    def _set_tensor_parallel_attribute(self):
        # Mark the sharded parameters as tensor-parallel for the framework.
        set_tensor_parallel_attribute(self.cls_token)
        set_tensor_parallel_attribute(self.pos_embed)

    def _broadcast_params(self) -> None:
        " broadcast to all column ranks for data consistency "
        xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
        dist.broadcast(self.cls_token, src=xz_rank[0],
                       group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
        dist.broadcast(self.pos_embed, src=xz_rank[0],
                       group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))

    def _sync_grad_hook(self, grad) -> Tensor:
        # FIX: return annotation was `-> None`, but this is a Tensor hook:
        # the returned tensor replaces the gradient.
        dist.all_reduce(grad, group=gpc.get_group(
            ParallelMode.PARALLEL_2P5D_XZ))
        # Average the summed gradient over the XZ group size.
        grad = grad / self.tesseract_dim / self.tesseract_dep  # *
        return grad

    def forward(self, x: Tensor) -> Tensor:
        # stole cls_tokens impl from Phil Wang, thanks
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        return x
@LAYERS.register_module
class ViTInputSplitter2p5D(nn.Module):
    """Split the input batch across the 2.5D XZ process group so that each
    rank keeps one contiguous slice of the batch dimension."""

    def __init__(self):
        super().__init__()
        assert_tesseract_initialization()
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()

    def forward(self, x: Tensor) -> Tensor:
        full_batch_size = x.size(0)
        return _ViT_Split_2p5D.apply(
            x,
            full_batch_size,
            self.tesseract_dim,
            self.tesseract_dep,
            ParallelMode.PARALLEL_2P5D_XZ,
        )
colossalai/nn/layer/parallel_2p5d/layers.py
0 → 100644
View file @
404ecbdc
import
math
import
torch
from
torch
import
Tensor
from
torch.nn
import
Parameter
,
init
as
init
from
colossalai.context
import
seed
,
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.registry
import
LAYERS
from
colossalai.utils
import
get_current_device
from
._operation
import
Matmul_AB_2p5D
,
Add_Bias_2p5D
,
_LayerNorm_2p5D
from
._utils
import
get_tesseract_dim_dep_from_env
,
assert_tesseract_initialization
from
.._common_utils
import
divide
,
set_tensor_parallel_attribute
from
..base_layer
import
ParallelLayer
@LAYERS.register_module
class Linear2p5D(ParallelLayer):
    """Linear layer for 2.5D parallelism

    Weight is sharded as [in_features/q, out_features/q] per rank, where q
    is the tesseract dimension. When ``skip_bias_add`` is True and a bias
    exists, ``forward`` returns a ``(output, bias)`` tuple instead of a
    single tensor.

    :param in_features: size of each input sample
    :type in_features: int
    :param out_features: size of each output sample
    :type out_features: int
    :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
    :type bias: bool, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    :param skip_bias_add: If True, return the bias separately instead of adding it, defaults to False
    :type skip_bias_add: bool, optional
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype=None,
                 skip_bias_add: bool = False
                 ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.skip_bias_add = skip_bias_add

        # parallel setting
        assert_tesseract_initialization()
        # NOTE: row_rank comes from the COL mode and col_rank from the ROW
        # mode — intentional per the surrounding 2.5D convention.
        self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
        self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
        self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()

        # partitioning dimension
        self.input_size_per_partition = divide(in_features,
                                               self.tesseract_dim)
        self.hidden_size_per_partition = divide(out_features,
                                                self.tesseract_dim)

        # create weight, shape: [k/q, h/q]
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
        self.weight = Parameter(torch.empty(
            self.input_size_per_partition,
            self.hidden_size_per_partition,
            **factory_kwargs))

        # create bias, shape: [h/q]
        if bias:
            self.bias = Parameter(torch.empty(
                self.hidden_size_per_partition,
                **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        # initialize parameters
        self.reset_parameters()
        self._set_tensor_parallel_attributes()

    def _set_tensor_parallel_attributes(self):
        # Mark sharded parameters as tensor-parallel for the framework.
        set_tensor_parallel_attribute(self.weight)
        if self.bias is not None:
            set_tensor_parallel_attribute(self.bias)

    def reset_parameters(self) -> None:
        # Kaiming-uniform-style init computed from the *global* fan-in,
        # performed under the shared TENSOR seed so that sharded ranks draw
        # consistent random streams.
        # setting
        fan_in = self.in_features
        a = math.sqrt(5)
        nonlinearity = 'leaky_relu'

        # init weight
        std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
        bound = math.sqrt(3.0) * std
        with seed(ParallelMode.TENSOR):
            init.uniform_(self.weight, -bound, bound)

        # init bias
        if self.bias is not None:
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            with seed(ParallelMode.TENSOR):
                init.uniform_(self.bias, -bound, bound)

    def forward(self, x: Tensor) -> Tensor:
        # input: [m/dq, n/q, k/q]
        # output: [m/dq, n/q, h/q]
        out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)

        # NOTE(review): self.data_parallel_rank / pipeline_parallel_* /
        # tensor_parallel_size are not set in this class — presumably
        # provided by the ParallelLayer base; TODO confirm.
        output = Matmul_AB_2p5D.apply(
            x,
            self.weight,
            self.tesseract_dim,
            self.tesseract_dep,
            out_shape,
            self.row_rank, self.col_rank, self.dep_rank,
            ParallelMode.PARALLEL_2P5D_ROW,
            ParallelMode.PARALLEL_2P5D_COL,
            ParallelMode.PARALLEL_2P5D_DEP,
            self.data_parallel_rank,
            self.pipeline_parallel_rank,
            self.pipeline_parallel_size,
            self.tensor_parallel_size,
        )

        if self.bias is not None:
            if self.skip_bias_add:
                # Return the broadcast bias separately for fused handling.
                bias = Add_Bias_2p5D.apply(
                    None,
                    self.bias,
                    self.hidden_size_per_partition,
                    self.tesseract_dim, self.tesseract_dep,
                    self.row_rank, self.col_rank, self.dep_rank,
                    ParallelMode.PARALLEL_2P5D_ROW,
                    ParallelMode.PARALLEL_2P5D_COL,
                    ParallelMode.PARALLEL_2P5D_DEP,
                    True,
                    self.data_parallel_rank,
                    self.pipeline_parallel_rank,
                    self.pipeline_parallel_size,
                    self.tensor_parallel_size)
                return output, bias
            else:
                output = Add_Bias_2p5D.apply(
                    output,
                    self.bias,
                    self.hidden_size_per_partition,
                    self.tesseract_dim, self.tesseract_dep,
                    self.row_rank, self.col_rank, self.dep_rank,
                    ParallelMode.PARALLEL_2P5D_ROW,
                    ParallelMode.PARALLEL_2P5D_COL,
                    ParallelMode.PARALLEL_2P5D_DEP,
                    False,
                    self.data_parallel_rank,
                    self.pipeline_parallel_rank,
                    self.pipeline_parallel_size,
                    self.tensor_parallel_size)
                return output
        else:
            return output
@LAYERS.register_module
class LayerNorm2p5D(ParallelLayer):
    r"""Layer Normalization for 2.5D parallelism.

    :param normalized_shape: input shape from an expected input
        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
        If a single integer is used, it is treated as a singleton list, and this module will
        normalize over the last dimension which is expected to be of that specific size.
    :type normalized_shape: int
    :param eps: a value added to the denominator for numerical stability, defaults to 1e-05
    :type eps: float, optional
    :param dtype: The dtype of parameters, defaults to None
    :type dtype: torch.dtype, optional
    """

    def __init__(self,
                 normalized_shape: int,
                 eps: float = 1e-05,
                 dtype=None):
        super().__init__()
        # layer norm config
        self.normalized_shape = normalized_shape
        self.variance_epsilon = eps
        # parallel setting
        assert_tesseract_initialization()
        # NOTE(review): row_rank is read from the COL group and col_rank from
        # the ROW group — this mirrors the rest of the 2.5D code; confirm the
        # naming convention against the other 2p5d layers.
        self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
        self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
        self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
        self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
        # partitioning dimension: each rank holds normalized_shape / dim elements
        self.partitioned_partition = divide(normalized_shape, self.tesseract_dim)
        # create parameters
        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
        if self.row_rank == 0:
            # Only row 0 holds the real (partitioned) affine parameters.
            self.gamma = Parameter(torch.ones(
                self.partitioned_partition,
                **factory_kwargs))
            self.beta = Parameter(torch.zeros(
                self.partitioned_partition,
                **factory_kwargs))
        else:
            # Other rows keep scalar placeholders; the broadcast inside
            # Add_Bias_2p5D distributes the real values at forward time.
            self.gamma = Parameter(torch.tensor(
                1.0,
                requires_grad=True,
                **factory_kwargs))
            self.beta = Parameter(torch.tensor(
                1.0,
                requires_grad=True,
                **factory_kwargs))
        self._set_tensor_parallel_attribute()

    def _set_tensor_parallel_attribute(self):
        # Mark affine parameters as tensor-parallel so they are handled
        # correctly by the framework (e.g. during gradient handling).
        set_tensor_parallel_attribute(self.gamma)
        set_tensor_parallel_attribute(self.beta)

    def forward(self, x: Tensor) -> Tensor:
        # The statistics are computed without autograd; the gradient of the
        # normalization itself is implemented in _LayerNorm_2p5D.
        with torch.no_grad():
            E_x = torch.sum(x, dim=-1, keepdim=True)  # [b/q, s, 1]
            # Sum partial row sums across the ROW group to get the full mean.
            torch.distributed.all_reduce(
                E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
            E_x /= self.normalized_shape
            # Var_x in the block below is the sum of input^2
            Var_x = torch.sum(x * x, dim=-1, keepdim=True)  # [b/q, s, 1]
            torch.distributed.all_reduce(
                Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
            Var_x /= self.normalized_shape
            Var_x = Var_x - E_x * E_x  # variance of x [b/q, s, 1]
            # this time 1/sqrt(Var_x + epsilon)
            Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
        # Normalize with the custom autograd function.
        output = _LayerNorm_2p5D.apply(x, E_x, Var_x, self.normalized_shape,
                                       ParallelMode.PARALLEL_2P5D_ROW,
                                       ParallelMode.PARALLEL_2P5D_COL,
                                       ParallelMode.PARALLEL_2P5D_DEP)
        # Broadcast the (partitioned) beta shard for this rank.
        bias = Add_Bias_2p5D.apply(
            None, self.beta, self.partitioned_partition,
            self.tesseract_dim, self.tesseract_dep,
            self.row_rank, self.col_rank, self.dep_rank,
            ParallelMode.PARALLEL_2P5D_ROW,
            ParallelMode.PARALLEL_2P5D_COL,
            ParallelMode.PARALLEL_2P5D_DEP,
            True,
            self.data_parallel_rank,
            self.pipeline_parallel_rank,
            self.pipeline_parallel_size,
            self.tensor_parallel_size)
        # Broadcast the (partitioned) gamma shard for this rank.
        scale = Add_Bias_2p5D.apply(
            None, self.gamma, self.partitioned_partition,
            self.tesseract_dim, self.tesseract_dep,
            self.row_rank, self.col_rank, self.dep_rank,
            ParallelMode.PARALLEL_2P5D_ROW,
            ParallelMode.PARALLEL_2P5D_COL,
            ParallelMode.PARALLEL_2P5D_DEP,
            True,
            self.data_parallel_rank,
            self.pipeline_parallel_rank,
            self.pipeline_parallel_size,
            self.tensor_parallel_size)
        # output = bias + scale * normalized
        output = torch.addcmul(bias, scale, output)
        return output
colossalai/nn/layer/parallel_3d/__init__.py
0 → 100644
View file @
404ecbdc
from
._operation
import
Matmul_ABT_3D
,
Matmul_ATB_3D
,
Matmul_AB_3D
,
Mul_3D
,
Sum_3D
,
Add_3D
,
Reduce_3D
from
._vit
import
ViTHead3D
,
ViTMLP3D
,
ViTPatchEmbedding3D
,
ViTSelfAttention3D
from
.layers
import
Linear3D
,
LayerNorm3D
__all__
=
[
'Matmul_ABT_3D'
,
'Matmul_ATB_3D'
,
'Matmul_AB_3D'
,
'Mul_3D'
,
'Sum_3D'
,
'Add_3D'
,
'Reduce_3D'
,
'ViTHead3D'
,
'ViTMLP3D'
,
'ViTPatchEmbedding3D'
,
'ViTSelfAttention3D'
,
'Linear3D'
,
'LayerNorm3D'
]
colossalai/nn/layer/parallel_3d/_operation.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from
typing
import
Any
,
Tuple
import
torch
import
torch.distributed
as
dist
from
colossalai.communication
import
all_gather
,
reduce_scatter
,
scatter
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.utils
import
empty_cache
,
get_current_device
from
torch
import
Tensor
class Matmul_AB_3D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = AB` under 3D tensor parallelism.

    Each rank holds a shard of A and B; the forward pass all-gathers the
    shards, multiplies locally, then reduce-scatters the product so each
    rank keeps its shard of C.
    """

    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                depth: int,
                input_parallel_mode: ParallelMode,
                weight_parallel_mode: ParallelMode,
                output_parallel_mode: ParallelMode,
                input_dim: int = 0,
                weight_dim: int = -1,
                output_dim: int = 0) -> Tensor:
        # A: [m/q^2, n, k/q]
        # B: [k/q, h/q^2]
        # C: [m/q^2, n, h/q]
        empty_cache()
        ctx.save_for_backward(A, B)

        assert A.shape[-1] == B.shape[0], \
            'Invalid shapes: A={}, B={}.'.format(A.shape, B.shape)

        # Gather full operands along the sharded dimensions, multiply,
        # then scatter the result back so memory stays balanced.
        A_temp = all_gather(A, input_dim, input_parallel_mode)
        B_temp = all_gather(B, weight_dim, weight_parallel_mode)

        C = torch.matmul(A_temp, B_temp)
        out = reduce_scatter(C, output_dim, output_parallel_mode)

        # Save the group/dim bookkeeping so backward can swap roles.
        ctx.depth = depth
        ctx.A_group_parallel_mode = input_parallel_mode
        ctx.B_group_parallel_mode = weight_parallel_mode
        ctx.C_group_parallel_mode = output_parallel_mode
        ctx.A_dim = input_dim
        ctx.B_dim = weight_dim
        ctx.C_dim = output_dim

        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        A, B = ctx.saved_tensors
        with torch.no_grad():
            # dA = dC @ B^T, computed with the C/B/A groups and dims swapped
            # to match the transposed data layout.
            A_grad = Matmul_ABT_3D.apply(output_grad, B, ctx.depth,
                                         ctx.C_group_parallel_mode,
                                         ctx.B_group_parallel_mode,
                                         ctx.A_group_parallel_mode,
                                         ctx.C_dim, ctx.B_dim, ctx.A_dim)
            # dB = A^T @ dC
            B_grad = Matmul_ATB_3D.apply(A, output_grad, ctx.depth,
                                         ctx.A_group_parallel_mode,
                                         ctx.C_group_parallel_mode,
                                         ctx.B_group_parallel_mode,
                                         ctx.A_dim, ctx.C_dim, ctx.B_dim)
        # One gradient slot per forward argument; non-tensor args get None.
        return A_grad, B_grad, None, None, None, None, None, None, None
class Matmul_ABT_3D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = AB^T` under 3D tensor parallelism.

    Same gather/multiply/reduce-scatter scheme as ``Matmul_AB_3D`` but with
    the weight operand transposed before the local matmul.
    """

    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                depth: int,
                input_parallel_mode: ParallelMode,
                weight_parallel_mode: ParallelMode,
                output_parallel_mode: ParallelMode,
                input_dim: int = 0,
                weight_dim: int = -1,
                output_dim: int = 0) -> Tensor:
        # A: [m/q^2, n, h/q]
        # B: [k/q, h/q^2]
        # C: [m/q^2, n, k/q]
        empty_cache()
        ctx.save_for_backward(A, B)

        A_temp = all_gather(A, input_dim, input_parallel_mode)
        B_temp = all_gather(B, weight_dim, weight_parallel_mode)

        # Local product against the transposed weight shard.
        C = torch.matmul(A_temp, B_temp.transpose(0, 1))
        out = reduce_scatter(C, output_dim, output_parallel_mode)

        # Bookkeeping for backward.
        ctx.depth = depth
        ctx.A_group_parallel_mode = input_parallel_mode
        ctx.B_group_parallel_mode = weight_parallel_mode
        ctx.C_group_parallel_mode = output_parallel_mode
        ctx.A_dim = input_dim
        ctx.B_dim = weight_dim
        ctx.C_dim = output_dim

        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        A, B = ctx.saved_tensors
        with torch.no_grad():
            # dA = dC @ B
            A_grad = Matmul_AB_3D.apply(output_grad, B, ctx.depth,
                                        ctx.C_group_parallel_mode,
                                        ctx.B_group_parallel_mode,
                                        ctx.A_group_parallel_mode,
                                        ctx.C_dim, ctx.B_dim, ctx.A_dim)
            # dB = dC^T @ A
            B_grad = Matmul_ATB_3D.apply(output_grad, A, ctx.depth,
                                         ctx.C_group_parallel_mode,
                                         ctx.A_group_parallel_mode,
                                         ctx.B_group_parallel_mode,
                                         ctx.C_dim, ctx.A_dim, ctx.B_dim)
        return A_grad, B_grad, None, None, None, None, None, None, None
class Matmul_ATB_3D(torch.autograd.Function):
    """Matrix multiplication for :math:`C = A^TB` under 3D tensor parallelism.

    Both operands are gathered and flattened to 2D before the transposed
    local matmul, so the result is a 2D weight-shaped tensor.
    """

    @staticmethod
    def forward(ctx: Any,
                A: Tensor,
                B: Tensor,
                depth: int,
                input_parallel_mode: ParallelMode,
                weight_parallel_mode: ParallelMode,
                output_parallel_mode: ParallelMode,
                input_dim: int = 0,
                weight_dim: int = 0,
                output_dim: int = -1) -> Tensor:
        # A: [m/q^2, n, k/q]
        # B: [m/q^2, n, h/q]
        # C: [k/q, h/q^2]
        empty_cache()
        ctx.save_for_backward(A, B)

        A_temp = all_gather(A, input_dim, input_parallel_mode)
        # Flatten leading dims so the transposed matmul is plain 2D.
        A_temp = A_temp.reshape(-1, A.shape[-1])
        B_temp = all_gather(B, weight_dim, weight_parallel_mode)
        B_temp = B_temp.reshape(-1, B.shape[-1])

        C = torch.matmul(A_temp.transpose(0, 1), B_temp)
        out = reduce_scatter(C, output_dim, output_parallel_mode)

        # Bookkeeping for backward.
        ctx.depth = depth
        ctx.A_group_parallel_mode = input_parallel_mode
        ctx.B_group_parallel_mode = weight_parallel_mode
        ctx.C_group_parallel_mode = output_parallel_mode
        ctx.A_dim = input_dim
        ctx.B_dim = weight_dim
        ctx.C_dim = output_dim

        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        A, B = ctx.saved_tensors
        with torch.no_grad():
            # dA = B @ dC^T
            A_grad = Matmul_ABT_3D.apply(B, output_grad, ctx.depth,
                                         ctx.B_group_parallel_mode,
                                         ctx.C_group_parallel_mode,
                                         ctx.A_group_parallel_mode,
                                         ctx.B_dim, ctx.C_dim, ctx.A_dim)
            # dB = A @ dC
            B_grad = Matmul_AB_3D.apply(A, output_grad, ctx.depth,
                                        ctx.A_group_parallel_mode,
                                        ctx.C_group_parallel_mode,
                                        ctx.B_group_parallel_mode,
                                        ctx.A_dim, ctx.C_dim, ctx.B_dim)
        return A_grad, B_grad, None, None, None, None, None, None, None
class Add_3D(torch.autograd.Function):
    """Matrix add bias: :math:`C = A + b` under 3D tensor parallelism.

    The bias shard lives on one rank per input group; it is broadcast
    within the input group and all-gathered over the weight group before
    being added to the local input shard.
    """

    @staticmethod
    def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
                input_parallel_mode: ParallelMode,
                weight_parallel_mode: ParallelMode,
                output_parallel_mode: ParallelMode) -> Tensor:
        # input: [m/q^2, n, h/q]
        # bias: [h/q^2]
        # The source rank in the input group is chosen by this rank's
        # position in the output group.
        ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
        src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
        # Clone first so the broadcast does not clobber the parameter.
        bias_temp = bias.clone()
        dist.broadcast(bias_temp,
                       src=src_rank,
                       group=gpc.get_group(input_parallel_mode))
        # [h/q]
        bias_temp = all_gather(bias_temp, -1, weight_parallel_mode)

        out = input_ + bias_temp

        ctx.depth = depth
        ctx.src_rank = src_rank
        ctx.A_group_parallel_mode = input_parallel_mode
        ctx.B_group_parallel_mode = weight_parallel_mode
        ctx.C_group_parallel_mode = output_parallel_mode

        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        # output_grad: [m/q^2, n, h/q]
        with torch.no_grad():
            # Sum over all but the last dimension -> bias-shaped grad [h/q].
            grad = torch.sum(output_grad,
                             dim=tuple(range(len(output_grad.shape))[:-1]))
            bias_grad = reduce_scatter(grad, -1,
                                       ctx.B_group_parallel_mode)
            # Reduce onto the rank that owns the bias shard.
            dist.reduce(bias_grad,
                        dst=ctx.src_rank,
                        group=gpc.get_group(ctx.A_group_parallel_mode))
            # Only the owning rank keeps the bias gradient; others get None.
            if gpc.get_local_rank(
                    ctx.A_group_parallel_mode) != gpc.get_local_rank(
                        ctx.C_group_parallel_mode):
                bias_grad = None
        return output_grad, bias_grad, None, None, None, None
class Mul_3D(torch.autograd.Function):
    """Elementwise scale by a parameter vector: :math:`C = A * b` under
    3D tensor parallelism.

    Same bias distribution scheme as ``Add_3D`` (broadcast in the input
    group, all-gather over the weight group), but multiplies instead of
    adding.
    """

    @staticmethod
    def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
                input_parallel_mode: ParallelMode,
                weight_parallel_mode: ParallelMode,
                output_parallel_mode: ParallelMode) -> Tensor:
        # input: [m/q^2, n, h/q]
        # bias: [h/q^2]
        ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
        src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
        # [h/q^2] — clone so the broadcast does not clobber the parameter.
        bias_temp = bias.clone()
        dist.broadcast(bias_temp,
                       src=src_rank,
                       group=gpc.get_group(input_parallel_mode))
        # [h/q]
        bias_temp = all_gather(bias_temp, -1, weight_parallel_mode)

        empty_cache()
        # Note: the gathered copy of the bias is what gets saved, since the
        # backward pass needs the full [h/q] vector.
        ctx.save_for_backward(input_, bias_temp)

        out = torch.mul(input_, bias_temp)

        ctx.depth = depth
        ctx.src_rank = src_rank
        ctx.A_group_parallel_mode = input_parallel_mode
        ctx.B_group_parallel_mode = weight_parallel_mode
        ctx.C_group_parallel_mode = output_parallel_mode

        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        # output_grad: [m/q^2, n, h/q]
        with torch.no_grad():
            input_, bias = ctx.saved_tensors
            # [m/q^2, n, h/q]
            input_grad = torch.mul(output_grad, bias)
            # [h/q]
            grad = torch.mul(output_grad, input_)
            grad = torch.sum(grad,
                             dim=tuple(range(len(output_grad.shape))[:-1]))
            bias_grad = reduce_scatter(grad, -1,
                                       ctx.B_group_parallel_mode)
            # Reduce onto the owning rank; non-owners return None.
            dist.reduce(bias_grad,
                        dst=ctx.src_rank,
                        group=gpc.get_group(ctx.A_group_parallel_mode))
            if gpc.get_local_rank(
                    ctx.A_group_parallel_mode) != gpc.get_local_rank(
                        ctx.C_group_parallel_mode):
                bias_grad = None
        return input_grad, bias_grad, None, None, None, None
class Sum_3D(torch.autograd.Function):
    """Compute the sum of the input tensor along ``dim``, all-reduced across
    the given parallel group.
    """

    @staticmethod
    def forward(ctx: Any,
                input_: Tensor,
                dim: int,
                depth: int,
                parallel_mode: ParallelMode,
                keepdim: bool = False) -> Tensor:
        # input: [m/q^2, n, h/q]
        out = torch.sum(input_, dim=dim, keepdim=keepdim)
        # Combine partial sums from every rank in the group.
        dist.all_reduce(out, group=gpc.get_group(parallel_mode))

        ctx.input_shape = input_.shape
        ctx.depth = depth
        ctx.group = parallel_mode
        ctx.dim = dim
        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        with torch.no_grad():
            output_grad = output_grad.contiguous()
            # Mirror the forward all-reduce on the incoming gradient.
            dist.all_reduce(output_grad, group=gpc.get_group(ctx.group))
            # Re-insert the reduced dimension if it was squeezed out
            # (i.e. the forward used keepdim=False).
            if len(output_grad.shape) < len(ctx.input_shape):
                output_grad = torch.unsqueeze(output_grad, ctx.dim)
            # Broadcast the gradient along the summed dimension.
            dims = [1 for _ in range(len(output_grad.shape))]
            dims[ctx.dim] = ctx.input_shape[ctx.dim]
            input_grad = output_grad.repeat(tuple(dims))
        return input_grad, None, None, None, None, None
class Reduce_3D(torch.autograd.Function):
    """All-reduce the input tensor over the given parallel group.

    The forward pass is an identity in value terms (every rank ends up with
    the group sum); the backward pass passes the gradient straight through.
    """

    @staticmethod
    def forward(ctx: Any, input_: Tensor, depth: int,
                parallel_mode: ParallelMode) -> Tensor:
        # NOTE(review): all_reduce mutates input_ in place; the clone is
        # returned so autograd sees a fresh output tensor.
        dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
        return input_.clone()

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        # Identity gradient; non-tensor args get None.
        return output_grad, None, None
class Slice_3D(torch.autograd.Function):
    """Slice the input tensor along ``dim`` into ``depth`` chunks and keep
    the chunk owned by this rank in ``parallel_mode``.
    """

    @staticmethod
    def forward(ctx: Any, input_: Tensor, dim: int, depth: int,
                parallel_mode: ParallelMode) -> Tensor:
        """Return this rank's contiguous chunk of ``input_`` along ``dim``.

        :param input_: tensor to slice
        :param dim: dimension to slice along
        :param depth: number of chunks (the 3D parallelism depth)
        :param parallel_mode: group whose local rank selects the chunk
        """
        rank = gpc.get_local_rank(parallel_mode)
        out = torch.chunk(input_, depth, dim=dim)[rank].contiguous()

        ctx.depth = depth
        ctx.parallel_mode = parallel_mode
        ctx.dim = dim
        ctx.input_shape = input_.shape

        return out

    @staticmethod
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
        """Gather the per-rank chunk gradients back into the full gradient."""
        with torch.no_grad():
            input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode)
            # Bug fix: Tensor.reshape returns a new tensor and is not
            # in-place; the original code discarded its result, so the
            # returned gradient kept the gathered (unreshaped) layout.
            input_grad = input_grad.reshape(ctx.input_shape)
        return input_grad, None, None, None
colossalai/nn/layer/parallel_3d/_utils.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
os
from
colossalai.constants
import
DEPTH_3D
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
torch
import
Tensor
def get_depth_from_env() -> int:
    """Read the 3D tensor-parallelism depth from the environment.

    :return: the depth as a positive integer
    :raises EnvironmentError: if the ``DEPTH_3D`` variable is not set
    :raises ValueError: if the variable is set but not a valid integer
    """
    try:
        depth = os.environ[DEPTH_3D]
        depth = int(depth)
        assert depth > 0, 'DEPTH must be greater than zero'
        return depth
    except KeyError as e:
        # Fix: chain the original KeyError (the bound name was previously
        # unused) so the root cause remains visible in tracebacks.
        raise EnvironmentError(
            'DEPTH is not found in the current environment, '
            'please make sure that you have used the correct process group initializer'
        ) from e
def get_last_group(a, b):
    """Given two of the three 3D parallel modes, return the remaining one.

    The modes are labelled 'A', 'B' and 'C'; the missing label is recovered
    from the character-code sum of all three labels minus the two given.
    """
    letter_of = {
        ParallelMode.PARALLEL_3D_INPUT: 'A',
        ParallelMode.PARALLEL_3D_WEIGHT: 'B',
        ParallelMode.PARALLEL_3D_OUTPUT: 'C',
    }
    total = ord('A') + ord('B') + ord('C')
    missing = chr(total - ord(letter_of[a]) - ord(letter_of[b]))
    if missing == 'A':
        return ParallelMode.PARALLEL_3D_INPUT
    elif missing == 'B':
        return ParallelMode.PARALLEL_3D_WEIGHT
    elif missing == 'C':
        return ParallelMode.PARALLEL_3D_OUTPUT
def dbg_check_shape(tensor: Tensor, shape: tuple):
    """Debug helper: print the tensor shape on global rank 0, then assert on
    every rank that it matches the expected ``shape``.
    """
    global_rank = gpc.get_global_rank()
    if global_rank == 0:
        print(tensor.shape)
    assert tensor.shape == shape, \
        '{} does not match {}'.format(tensor.shape, shape)
colossalai/nn/layer/parallel_3d/_vit.py
0 → 100644
View file @
404ecbdc
import
math
from
typing
import
Tuple
import
torch
import
torch.distributed
as
dist
from
colossalai.context
import
ParallelMode
,
seed
from
colossalai.core
import
global_context
as
gpc
from
colossalai.registry
import
LAYERS
from
colossalai.utils
import
checkpoint
,
get_current_device
from
torch
import
Tensor
,
dtype
,
nn
from
.._common_utils
import
ACT2FN
,
divide
,
set_tensor_parallel_attribute
from
..vanilla_vision_transformer.layers
import
to_2tuple
from
._utils
import
get_depth_from_env
from
.layers
import
Linear3D
@LAYERS.register_module
class ViTPatchEmbedding3D(nn.Module):
    """ 3D Image to Patch Embedding

    :param img_size: image size
    :type img_size: int
    :param patch_size: patch size
    :type patch_size: int
    :param in_chans: number of channels of input image
    :type in_chans: int
    :param embed_size: dimension of embedding
    :type embed_size: int
    :param drop_prob: dropout probability
    :type drop_prob: float
    :param flatten: whether to flatten output tensor, defaults to True
    :type flatten: bool, optional
    """

    def __init__(self,
                 img_size: int,
                 patch_size: int,
                 in_chans: int,
                 embed_size: int,
                 drop_prob: float,
                 flatten: bool = True):
        super().__init__()
        self.depth = get_depth_from_env()
        self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
        self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
        self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.embed_size = embed_size
        # Embedding dim is sharded across the depth of the 3D grid.
        self.embed_size_per_partition = divide(self.embed_size, self.depth)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        # Initialize under the TENSOR seed so all tensor-parallel ranks draw
        # from the same RNG stream.
        with seed(ParallelMode.TENSOR):
            self.proj = nn.Conv2d(in_chans,
                                  self.embed_size_per_partition,
                                  kernel_size=patch_size,
                                  stride=patch_size)

        self.cls_token = nn.Parameter(
            torch.zeros(1, 1, self.embed_size_per_partition))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1,
                        self.embed_size_per_partition))
        self.pos_drop = nn.Dropout(drop_prob)

        # Make parameters identical across the input and weight groups,
        # then keep gradients in sync with all-reduce hooks.
        self._sync_parameters()
        self.proj.weight.register_hook(self._sync_grad_hook)
        self.proj.bias.register_hook(self._sync_grad_hook)
        self.cls_token.register_hook(self._sync_grad_hook)
        self.pos_embed.register_hook(self._sync_grad_hook)
        self._set_tensor_parallel_attribute()

    def _set_tensor_parallel_attribute(self):
        # Tag sharded parameters so the framework treats them as
        # tensor-parallel.
        set_tensor_parallel_attribute(self.proj.weight)
        set_tensor_parallel_attribute(self.proj.bias)
        set_tensor_parallel_attribute(self.cls_token)
        set_tensor_parallel_attribute(self.pos_embed)

    def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
        """Parallel groups the following layer should build on."""
        return self.input_parallel_mode, self.weight_parallel_mode

    def _sync_parameters(self):
        # Move to the local device, then broadcast the conv parameters from
        # the first rank of the weight group and of the input group so every
        # replica starts identical.
        self.to(get_current_device())
        weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
        dist.broadcast(self.proj.weight,
                       src=weight_src_rank,
                       group=gpc.get_group(self.weight_parallel_mode))
        dist.broadcast(self.proj.bias,
                       src=weight_src_rank,
                       group=gpc.get_group(self.weight_parallel_mode))
        input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
        dist.broadcast(self.proj.weight,
                       src=input_src_rank,
                       group=gpc.get_group(self.input_parallel_mode))
        dist.broadcast(self.proj.bias,
                       src=input_src_rank,
                       group=gpc.get_group(self.input_parallel_mode))
        set_tensor_parallel_attribute(self.proj.weight)
        set_tensor_parallel_attribute(self.proj.bias)
        set_tensor_parallel_attribute(self.cls_token)
        set_tensor_parallel_attribute(self.pos_embed)

    def _sync_grad_hook(self, grad) -> None:
        # Sum replicated gradients across both groups so replicas stay
        # identical after the optimizer step.
        dist.all_reduce(grad, group=gpc.get_group(self.input_parallel_mode))
        dist.all_reduce(grad, group=gpc.get_group(self.weight_parallel_mode))
        return grad

    def forward(self, x: Tensor) -> Tensor:
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC

        # split a partition from embedded states: take this rank's batch
        # chunk once per group (weight, then input).
        x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
            self.weight_parallel_mode)].contiguous()
        x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
            self.input_parallel_mode)].contiguous()

        # add cls token & pos embedding
        # [b/q^2,s,h/q] --> [b/q^2, 1+s, h/q]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        with seed(ParallelMode.TENSOR):
            x = self.pos_drop(x + self.pos_embed)

        return x
@LAYERS.register_module
class ViTSelfAttention3D(nn.Module):
    """Self-attention layer for 3D parallel Vision Transformer

    :param hidden_size: hidden size
    :type hidden_size: int
    :param num_attention_heads: number of attention heads
    :type num_attention_heads: int
    :param attention_probs_dropout_prob: dropout probability for attention layers
    :type attention_probs_dropout_prob: float
    :param hidden_dropout_prob: dropout probability for hidden layers
    :type hidden_dropout_prob: float
    :param dtype: dtype of parameters, defaults to None
    :type dtype: dtype, optional
    :param bias: whether to add bias, defaults to True
    :type bias: bool, optional
    :param checkpoint: whether to use activation checkpointing, defaults to False
    :type checkpoint: bool, optional
    """

    def __init__(self,
                 hidden_size: int,
                 num_attention_heads: int,
                 attention_probs_dropout_prob: float,
                 hidden_dropout_prob: float,
                 dtype: dtype = None,
                 bias: bool = True,
                 checkpoint: bool = False):
        super().__init__()
        self.depth = get_depth_from_env()
        self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
        self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
        self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
        self.hidden_size = hidden_size
        # Heads are sharded across the depth; head size uses the global count.
        self.num_attention_heads = divide(num_attention_heads, self.depth)
        self.attention_head_size = divide(hidden_size, num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.checkpoint = checkpoint

        # Fused projection producing Q, K and V in one matmul.
        self.query_key_value = Linear3D(self.hidden_size,
                                        3 * self.hidden_size,
                                        self.input_parallel_mode,
                                        self.weight_parallel_mode,
                                        dtype=dtype,
                                        bias=bias)
        self.attention_dropout = nn.Dropout(attention_probs_dropout_prob)
        self.dense = Linear3D(self.hidden_size,
                              self.hidden_size,
                              self.output_parallel_mode,
                              self.weight_parallel_mode,
                              dtype=dtype,
                              bias=bias)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.softmax = nn.Softmax(dim=-1)

    def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
        """Parallel groups the following layer should build on."""
        return self.input_parallel_mode, self.weight_parallel_mode

    def _forward(self, hidden_states: Tensor) -> Tensor:
        query_key_value = self.query_key_value(hidden_states)
        # [b, s, 3h] -> [b, s, heads, 3 * head_size]
        new_qkv_shape = query_key_value.shape[:-1] + \
            (self.num_attention_heads, 3 * self.attention_head_size)
        query_key_value = query_key_value.view(new_qkv_shape)
        # -> [b, heads, s, 3 * head_size]
        query_key_value = query_key_value.permute((0, 2, 1, 3))
        query_layer, key_layer, value_layer = torch.chunk(query_key_value,
                                                          3,
                                                          dim=-1)

        # Scaled dot-product attention.
        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(
            self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        # Dropout under the TENSOR seed keeps masks consistent across ranks.
        with seed(ParallelMode.TENSOR):
            attention_probs = self.attention_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.transpose(1, 2)
        # Merge heads back into a single hidden dimension.
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.all_head_size, )
        context_layer = context_layer.reshape(new_context_layer_shape)

        output = self.dense(context_layer)
        with seed(ParallelMode.TENSOR):
            output = self.dropout(output)

        return output

    def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
        # Recompute activations during backward to save memory.
        return checkpoint(self._forward, hidden_states)

    def forward(self, hidden_states: Tensor) -> Tensor:
        if self.checkpoint:
            return self._checkpoint_forward(hidden_states)
        else:
            return self._forward(hidden_states)
@LAYERS.register_module
class ViTMLP3D(nn.Module):
    """MLP block for 3D parallel Vision Transformer.

    :param hidden_size: hidden size
    :type hidden_size: int
    :param mlp_ratio: hidden size of MLP divided by embedding dim
    :type mlp_ratio: int
    :param hidden_dropout_prob: dropout probability for hidden layers
    :type hidden_dropout_prob: float
    :param hidden_act: activation function for hidden layers
    :type hidden_act: str
    :param dtype: dtype of parameters, defaults to None
    :type dtype: dtype, optional
    :param bias: whether to add bias, defaults to True
    :type bias: bool, optional
    :param checkpoint: whether to use activation checkpointing, defaults to False
    :type checkpoint: bool, optional
    """

    def __init__(self,
                 hidden_size: int,
                 mlp_ratio: int,
                 hidden_dropout_prob: float,
                 hidden_act: str = 'gelu',
                 dtype: dtype = None,
                 bias: bool = True,
                 checkpoint: bool = False):
        super().__init__()
        self.depth = get_depth_from_env()
        self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
        self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
        self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
        self.hidden_size = hidden_size
        self.mlp_ratio = mlp_ratio
        self.checkpoint = checkpoint

        # Expansion projection: h -> mlp_ratio * h.
        self.dense_1 = Linear3D(self.hidden_size,
                                self.mlp_ratio * self.hidden_size,
                                self.input_parallel_mode,
                                self.weight_parallel_mode,
                                dtype=dtype,
                                bias=bias)
        self.activation_func = ACT2FN[hidden_act]
        # Contraction projection: mlp_ratio * h -> h.
        self.dense_2 = Linear3D(self.mlp_ratio * self.hidden_size,
                                self.hidden_size,
                                self.output_parallel_mode,
                                self.weight_parallel_mode,
                                dtype=dtype,
                                bias=bias)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
        """Parallel groups the following layer should build on."""
        return self.input_parallel_mode, self.weight_parallel_mode

    def _forward(self, hidden_states: Tensor) -> Tensor:
        intermediate_output = self.dense_1(hidden_states)
        intermediate_output = self.activation_func(intermediate_output)

        output = self.dense_2(intermediate_output)
        # Dropout under the TENSOR seed keeps masks consistent across ranks.
        with seed(ParallelMode.TENSOR):
            output = self.dropout(output)
        return output

    def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
        # Recompute activations during backward to save memory.
        return checkpoint(self._forward, hidden_states)

    def forward(self, hidden_states: Tensor) -> Tensor:
        if self.checkpoint:
            return self._checkpoint_forward(hidden_states)
        else:
            return self._forward(hidden_states)
@LAYERS.register_module
class ViTHead3D(nn.Module):
    """Output layer for 3D parallel Vision Transformer

    :param in_features: size of input tensor
    :type in_features: int
    :param num_classes: number of classes
    :type num_classes: int
    :param dtype: dtype of parameters, defaults to None
    :type dtype: dtype, optional
    :param bias: whether to add bias, defaults to True
    :type bias: bool, optional
    """

    def __init__(self,
                 in_features: int,
                 num_classes: int,
                 dtype: dtype = None,
                 bias: bool = True):
        super().__init__()
        self.depth = get_depth_from_env()
        self.input_parallel_mode = ParallelMode.PARALLEL_3D_INPUT
        self.weight_parallel_mode = ParallelMode.PARALLEL_3D_WEIGHT
        self.output_parallel_mode = ParallelMode.PARALLEL_3D_OUTPUT
        self.in_features = in_features
        self.num_classes = num_classes
        # Pad the class count up to a multiple of depth^2 so the linear
        # layer's output can be sharded evenly; the padding columns are
        # discarded in forward().
        out_features = math.ceil(self.num_classes /
                                 (self.depth**2)) * (self.depth**2)
        self.num_classes_per_partition = divide(self.num_classes, self.depth)
        self.linear = Linear3D(self.in_features,
                               out_features,
                               self.input_parallel_mode,
                               self.weight_parallel_mode,
                               dtype=dtype,
                               bias=bias)

    def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
        """Parallel groups the following layer should build on."""
        return self.linear.groups_for_next_layer()

    def forward(self, x: Tensor) -> Tensor:
        # Classify on the class token only.
        # [b/q^2, s, h/q] --> [b/q^2, h/q]
        x = x[:, 0]
        # [b/q^2, h/q] --> [b/q^2, c/q]
        x = self.linear(x)
        # Drop the padding classes introduced in __init__.
        return x[:, :self.num_classes_per_partition]

    def extra_repr(self):
        return 'in_features={}, num_classes={}'.format(self.in_features,
                                                       self.num_classes)
Prev
1
2
3
4
5
6
7
8
…
21
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment