Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
4a0f8c2c
Commit
4a0f8c2c
authored
Mar 09, 2022
by
Yuer867
Committed by
Frank Lee
Mar 11, 2022
Browse files
fix format parallel_2p5d (#357)
parent
7eb87f51
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
81 additions
and
49 deletions
+81
-49
colossalai/nn/layer/parallel_2p5d/_operation.py
colossalai/nn/layer/parallel_2p5d/_operation.py
+75
-45
colossalai/nn/layer/parallel_2p5d/_utils.py
colossalai/nn/layer/parallel_2p5d/_utils.py
+2
-1
colossalai/nn/layer/parallel_2p5d/layers.py
colossalai/nn/layer/parallel_2p5d/layers.py
+4
-3
No files found.
colossalai/nn/layer/parallel_2p5d/_operation.py
View file @
4a0f8c2c
...
@@ -166,6 +166,7 @@ class Matmul_AB_2p5D(torch.autograd.Function):
...
@@ -166,6 +166,7 @@ class Matmul_AB_2p5D(torch.autograd.Function):
:param tensor_parallel_size: tensor parallel size
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
:type tensor_parallel_size: int
"""
"""
@
staticmethod
@
staticmethod
@
custom_fwd
(
cast_inputs
=
torch
.
float16
)
@
custom_fwd
(
cast_inputs
=
torch
.
float16
)
def
forward
(
ctx
:
Any
,
A
:
Tensor
,
B
:
Tensor
,
tesseract_dim
:
int
,
out_shape
:
Tuple
[
int
,
...],
row_rank
:
int
,
def
forward
(
ctx
:
Any
,
A
:
Tensor
,
B
:
Tensor
,
tesseract_dim
:
int
,
out_shape
:
Tuple
[
int
,
...],
row_rank
:
int
,
...
@@ -197,9 +198,13 @@ class Matmul_AB_2p5D(torch.autograd.Function):
...
@@ -197,9 +198,13 @@ class Matmul_AB_2p5D(torch.autograd.Function):
row_group
=
gpc
.
get_group
(
row_parallel_mode
)
row_group
=
gpc
.
get_group
(
row_parallel_mode
)
col_group
=
gpc
.
get_group
(
col_parallel_mode
)
col_group
=
gpc
.
get_group
(
col_parallel_mode
)
src_a
=
tesseract_dim
*
row_rank
+
tesseract_dim
**
2
*
dep_rank
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
src_a
=
\
tesseract_dim
*
row_rank
+
tesseract_dim
**
2
*
dep_rank
+
\
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
pipeline_parallel_rank
*
tensor_parallel_size
pipeline_parallel_rank
*
tensor_parallel_size
src_b
=
col_rank
+
tesseract_dim
**
2
*
dep_rank
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
src_b
=
\
col_rank
+
tesseract_dim
**
2
*
dep_rank
+
\
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
pipeline_parallel_rank
*
tensor_parallel_size
pipeline_parallel_rank
*
tensor_parallel_size
opa
=
[
None
]
*
2
opa
=
[
None
]
*
2
...
@@ -295,6 +300,7 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
...
@@ -295,6 +300,7 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
:param tensor_parallel_size: tensor parallel size
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
:type tensor_parallel_size: int
"""
"""
@
staticmethod
@
staticmethod
@
custom_fwd
(
cast_inputs
=
torch
.
float16
)
@
custom_fwd
(
cast_inputs
=
torch
.
float16
)
def
forward
(
ctx
:
Any
,
A
:
Tensor
,
B
:
Tensor
,
tesseract_dim
:
int
,
out_shape
:
Tuple
[
int
,
...],
row_rank
:
int
,
def
forward
(
ctx
:
Any
,
A
:
Tensor
,
B
:
Tensor
,
tesseract_dim
:
int
,
out_shape
:
Tuple
[
int
,
...],
row_rank
:
int
,
...
@@ -323,9 +329,13 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
...
@@ -323,9 +329,13 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
row_group
=
gpc
.
get_group
(
row_parallel_mode
)
row_group
=
gpc
.
get_group
(
row_parallel_mode
)
col_group
=
gpc
.
get_group
(
col_parallel_mode
)
col_group
=
gpc
.
get_group
(
col_parallel_mode
)
src_b
=
col_rank
+
tesseract_dim
**
2
*
dep_rank
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
src_b
=
\
col_rank
+
tesseract_dim
**
2
*
dep_rank
+
\
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
pipeline_parallel_rank
*
tensor_parallel_size
pipeline_parallel_rank
*
tensor_parallel_size
src_c
=
tesseract_dim
*
row_rank
+
tesseract_dim
**
2
*
dep_rank
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
src_c
=
\
tesseract_dim
*
row_rank
+
tesseract_dim
**
2
*
dep_rank
+
\
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
pipeline_parallel_rank
*
tensor_parallel_size
pipeline_parallel_rank
*
tensor_parallel_size
opb
=
[
None
]
*
2
opb
=
[
None
]
*
2
...
@@ -429,6 +439,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
...
@@ -429,6 +439,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
:param tensor_parallel_size: tensor parallel size
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
:type tensor_parallel_size: int
"""
"""
@
staticmethod
@
staticmethod
@
custom_fwd
(
cast_inputs
=
torch
.
float16
)
@
custom_fwd
(
cast_inputs
=
torch
.
float16
)
def
forward
(
ctx
:
Any
,
A
:
Tensor
,
B
:
Tensor
,
tesseract_dim
:
int
,
out_shape
:
Tuple
[
int
,
...],
row_rank
:
int
,
def
forward
(
ctx
:
Any
,
A
:
Tensor
,
B
:
Tensor
,
tesseract_dim
:
int
,
out_shape
:
Tuple
[
int
,
...],
row_rank
:
int
,
...
@@ -457,9 +468,13 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
...
@@ -457,9 +468,13 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
row_group
=
gpc
.
get_group
(
row_parallel_mode
)
row_group
=
gpc
.
get_group
(
row_parallel_mode
)
col_group
=
gpc
.
get_group
(
col_parallel_mode
)
col_group
=
gpc
.
get_group
(
col_parallel_mode
)
src_a
=
tesseract_dim
*
row_rank
+
tesseract_dim
**
2
*
dep_rank
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
src_a
=
\
tesseract_dim
*
row_rank
+
tesseract_dim
**
2
*
dep_rank
+
\
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
pipeline_parallel_rank
*
tensor_parallel_size
pipeline_parallel_rank
*
tensor_parallel_size
src_c
=
col_rank
+
tesseract_dim
**
2
*
dep_rank
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
src_c
=
\
col_rank
+
tesseract_dim
**
2
*
dep_rank
+
\
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
pipeline_parallel_rank
*
tensor_parallel_size
pipeline_parallel_rank
*
tensor_parallel_size
opa
=
[
None
]
*
2
opa
=
[
None
]
*
2
...
@@ -540,7 +555,9 @@ class _Add_Bias_2p5D(torch.autograd.Function):
...
@@ -540,7 +555,9 @@ class _Add_Bias_2p5D(torch.autograd.Function):
bias_temp
=
bias
.
clone
()
bias_temp
=
bias
.
clone
()
else
:
else
:
bias_temp
=
torch
.
zeros
(
output_size_per_partition
,
dtype
=
bias
.
dtype
,
device
=
get_current_device
())
bias_temp
=
torch
.
zeros
(
output_size_per_partition
,
dtype
=
bias
.
dtype
,
device
=
get_current_device
())
src_rank
=
col_rank
+
dep_rank
*
tesseract_dim
**
2
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
src_rank
=
\
col_rank
+
dep_rank
*
tesseract_dim
**
2
+
\
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
pipeline_parallel_rank
*
tensor_parallel_size
pipeline_parallel_rank
*
tensor_parallel_size
dist
.
broadcast
(
bias_temp
,
src
=
src_rank
,
group
=
get_parallel_group
(
col_parallel_mode
))
dist
.
broadcast
(
bias_temp
,
src
=
src_rank
,
group
=
get_parallel_group
(
col_parallel_mode
))
...
@@ -575,27 +592,37 @@ class _Add_Bias_2p5D(torch.autograd.Function):
...
@@ -575,27 +592,37 @@ class _Add_Bias_2p5D(torch.autograd.Function):
tensor_parallel_size
=
ctx
.
tensor_parallel_size
tensor_parallel_size
=
ctx
.
tensor_parallel_size
if
ctx
.
bias
:
if
ctx
.
bias
:
dst_rank
=
col_rank
+
dep_rank
*
(
dst_rank
=
\
tesseract_dim
**
2
)
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
col_rank
+
dep_rank
*
(
tesseract_dim
**
2
)
+
\
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
pipeline_parallel_rank
*
tensor_parallel_size
pipeline_parallel_rank
*
tensor_parallel_size
dist
.
reduce
(
output_grad
,
dst
=
dst_rank
,
group
=
get_parallel_group
(
col_parallel_mode
))
dist
.
reduce
(
output_grad
,
dst
=
dst_rank
,
group
=
get_parallel_group
(
col_parallel_mode
))
if
row_rank
==
0
:
if
row_rank
==
0
:
return
None
,
output_grad
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
return
\
None
,
output_grad
,
None
,
None
,
None
,
None
,
None
,
None
,
\
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
else
:
else
:
grad_tmp
=
torch
.
zeros_like
(
output_grad
)
grad_tmp
=
torch
.
zeros_like
(
output_grad
)
return
None
,
grad_tmp
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
return
\
None
,
grad_tmp
,
None
,
None
,
None
,
None
,
None
,
None
,
\
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
else
:
else
:
reduce_dim
=
tuple
(
range
(
output_grad
.
ndim
-
1
))
reduce_dim
=
tuple
(
range
(
output_grad
.
ndim
-
1
))
reduce
=
torch
.
sum
(
output_grad
,
dim
=
reduce_dim
)
reduce
=
torch
.
sum
(
output_grad
,
dim
=
reduce_dim
)
dst_rank
=
col_rank
+
dep_rank
*
(
dst_rank
=
\
tesseract_dim
**
2
)
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
col_rank
+
dep_rank
*
(
tesseract_dim
**
2
)
+
\
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
pipeline_parallel_rank
*
tensor_parallel_size
pipeline_parallel_rank
*
tensor_parallel_size
dist
.
reduce
(
reduce
,
dst
=
dst_rank
,
group
=
get_parallel_group
(
col_parallel_mode
))
dist
.
reduce
(
reduce
,
dst
=
dst_rank
,
group
=
get_parallel_group
(
col_parallel_mode
))
if
row_rank
==
0
:
if
row_rank
==
0
:
return
output_grad
,
reduce
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
return
\
output_grad
,
reduce
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
\
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
else
:
else
:
reduce_tmp
=
torch
.
zeros_like
(
reduce
)
reduce_tmp
=
torch
.
zeros_like
(
reduce
)
return
output_grad
,
reduce_tmp
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
return
\
output_grad
,
reduce_tmp
,
None
,
None
,
None
,
None
,
None
,
None
,
\
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
,
None
def
add_bias_2p5d
(
input
:
Tensor
,
bias
:
Tensor
,
output_size_per_partition
:
int
,
tesseract_dim
:
int
,
row_rank
:
int
,
def
add_bias_2p5d
(
input
:
Tensor
,
bias
:
Tensor
,
output_size_per_partition
:
int
,
tesseract_dim
:
int
,
row_rank
:
int
,
...
@@ -621,7 +648,8 @@ def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, t
...
@@ -621,7 +648,8 @@ def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, t
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion
:type skip_bias_add: bool
:type skip_bias_add: bool
:param data_parallel_rank: data parallel rank
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:type data_parallel_rank: int
...
@@ -652,6 +680,7 @@ class _Layernorm2p5D(torch.autograd.Function):
...
@@ -652,6 +680,7 @@ class _Layernorm2p5D(torch.autograd.Function):
:param row_parallel_mode: row parallel mode
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
"""
@
staticmethod
@
staticmethod
@
custom_fwd
(
cast_inputs
=
torch
.
float32
)
@
custom_fwd
(
cast_inputs
=
torch
.
float32
)
def
forward
(
ctx
:
Any
,
input
:
Tensor
,
E_x
:
Tensor
,
Var_x
:
Tensor
,
hidden_size
:
int
,
def
forward
(
ctx
:
Any
,
input
:
Tensor
,
E_x
:
Tensor
,
Var_x
:
Tensor
,
hidden_size
:
int
,
...
@@ -748,6 +777,7 @@ class SplitFirst(torch.autograd.Function):
...
@@ -748,6 +777,7 @@ class SplitFirst(torch.autograd.Function):
:param col_parallel_mode: column parallel mode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
"""
@
staticmethod
@
staticmethod
@
custom_fwd
(
cast_inputs
=
torch
.
float16
)
@
custom_fwd
(
cast_inputs
=
torch
.
float16
)
def
forward
(
ctx
:
Any
,
inputs
:
Tensor
,
tesseract_dim
:
int
,
col_parallel_mode
:
ParallelMode
)
->
Tensor
:
def
forward
(
ctx
:
Any
,
inputs
:
Tensor
,
tesseract_dim
:
int
,
col_parallel_mode
:
ParallelMode
)
->
Tensor
:
...
@@ -762,7 +792,7 @@ class SplitFirst(torch.autograd.Function):
...
@@ -762,7 +792,7 @@ class SplitFirst(torch.autograd.Function):
@
staticmethod
@
staticmethod
@
custom_bwd
@
custom_bwd
def
backward
(
ctx
:
Any
,
output_grad
:
Tensor
)
->
Tuple
[
Tensor
,
...]:
def
backward
(
ctx
:
Any
,
output_grad
:
Tensor
)
->
Tuple
[
Tensor
,
...]:
grad_shape
=
(
ctx
.
batch_size
,
)
+
output_grad
.
shape
[
1
:]
grad_shape
=
(
ctx
.
batch_size
,)
+
output_grad
.
shape
[
1
:]
grad
=
torch
.
empty
(
grad_shape
,
dtype
=
output_grad
.
dtype
,
device
=
get_current_device
())
grad
=
torch
.
empty
(
grad_shape
,
dtype
=
output_grad
.
dtype
,
device
=
get_current_device
())
dist
.
all_gather
(
list
(
grad
.
chunk
(
ctx
.
tesseract_dim
,
dim
=
0
)),
dist
.
all_gather
(
list
(
grad
.
chunk
(
ctx
.
tesseract_dim
,
dim
=
0
)),
output_grad
.
contiguous
(),
output_grad
.
contiguous
(),
...
...
colossalai/nn/layer/parallel_2p5d/_utils.py
View file @
4a0f8c2c
...
@@ -21,4 +21,5 @@ def assert_tesseract_initialization():
...
@@ -21,4 +21,5 @@ def assert_tesseract_initialization():
gpc
.
is_initialized
(
ParallelMode
.
PARALLEL_2P5D_ROW
)
and
\
gpc
.
is_initialized
(
ParallelMode
.
PARALLEL_2P5D_ROW
)
and
\
gpc
.
is_initialized
(
ParallelMode
.
PARALLEL_2P5D_DEP
)
and
\
gpc
.
is_initialized
(
ParallelMode
.
PARALLEL_2P5D_DEP
)
and
\
gpc
.
is_initialized
(
ParallelMode
.
PARALLEL_2P5D_XZ
),
\
gpc
.
is_initialized
(
ParallelMode
.
PARALLEL_2P5D_XZ
),
\
'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ must be initialized by the process group initializer'
'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ '
\
'must be initialized by the process group initializer'
colossalai/nn/layer/parallel_2p5d/layers.py
View file @
4a0f8c2c
...
@@ -134,8 +134,9 @@ class LayerNorm2p5D(ParallelLayer):
...
@@ -134,8 +134,9 @@ class LayerNorm2p5D(ParallelLayer):
r
"""
r
"""
Layer Normalization for 2.5D parallelism
Layer Normalization for 2.5D parallelism
:param normalized_shape: input shape from an expected input
:param normalized_shape: input shape from an expected input of size.
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
:type normalized_shape: int
...
@@ -431,7 +432,7 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
...
@@ -431,7 +432,7 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
def
_fill_padding_idx_with_zero
(
self
)
->
None
:
def
_fill_padding_idx_with_zero
(
self
)
->
None
:
if
self
.
padding_idx
is
not
None
and
\
if
self
.
padding_idx
is
not
None
and
\
self
.
padding_idx
>=
self
.
vocab_start_index
and
self
.
padding_idx
<
self
.
vocab_end_index
:
self
.
vocab_start_index
<=
self
.
padding_idx
<
self
.
vocab_end_index
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
self
.
weight
[
self
.
padding_idx
-
self
.
vocab_start_index
].
fill_
(
0
)
self
.
weight
[
self
.
padding_idx
-
self
.
vocab_start_index
].
fill_
(
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment