Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
cbb6436f
Commit
cbb6436f
authored
Mar 09, 2022
by
DouJS
Committed by
Frank Lee
Mar 11, 2022
Browse files
fix format for dir-[parallel_3d] (#333)
parent
eaac03ae
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
17 deletions
+22
-17
colossalai/nn/layer/parallel_3d/_operation.py
colossalai/nn/layer/parallel_3d/_operation.py
+14
-12
colossalai/nn/layer/parallel_3d/layers.py
colossalai/nn/layer/parallel_3d/layers.py
+8
-5
No files found.
colossalai/nn/layer/parallel_3d/_operation.py
View file @
cbb6436f
...
...
@@ -244,7 +244,7 @@ class _Layernorm3D(torch.autograd.Function):
def
layernorm_3d
(
input_
:
Tensor
,
weight
:
Tensor
,
bias
:
Tensor
,
normalized_shape
:
int
,
eps
:
float
,
input_parallel_mode
:
ParallelMode
,
weight_parallel_mode
:
ParallelMode
,
output_parallel_mode
:
ParallelMode
)
->
Tensor
:
"""
r
"""
3D parallel Layernorm
:param input_: input matrix
...
...
@@ -253,8 +253,9 @@ def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape:
:type weight: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
:param normalized_shape: input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
...
...
@@ -282,7 +283,7 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
:type tensor: torch.Tensor
:type dim: int
:type parallel_mode: colossalai.context.parallel_mode.ParallelMode
:return output: Split tensor
:rtype output: torch.Tensor
"""
...
...
@@ -294,9 +295,9 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
def
split_batch_3d
(
input_
:
Tensor
,
dim
:
int
=
0
,
input_parallel_mode
:
ParallelMode
=
ParallelMode
.
PARALLEL_3D_INPUT
,
weight_parallel_mode
:
ParallelMode
=
ParallelMode
.
PARALLEL_3D_WEIGHT
)
->
Tensor
:
dim
:
int
=
0
,
input_parallel_mode
:
ParallelMode
=
ParallelMode
.
PARALLEL_3D_INPUT
,
weight_parallel_mode
:
ParallelMode
=
ParallelMode
.
PARALLEL_3D_WEIGHT
)
->
Tensor
:
"""Splits 3D tensor in batch
:param input_: Input tensor
:param dim: Specified dimension in which to split
...
...
@@ -333,8 +334,8 @@ class _ReduceTensor3D(torch.autograd.Function):
def
reduce_tensor_3d
(
tensor
:
Tensor
,
parallel_mode
:
ParallelMode
)
->
Tensor
:
"""
All-reduce the input
.
All-reduce the input
:param tensor: Input tensor
:param parallel_mode: Parallel mode
"""
...
...
@@ -359,7 +360,7 @@ class _AllGatherTensor3D(torch.autograd.Function):
def
all_gather_tensor_3d
(
tensor
:
Tensor
,
dim
:
int
,
parallel_mode
:
ParallelMode
)
->
Tensor
:
"""
All-reduce the gradient in backward pass.
:param tensor: Input tensor
:param parallel_mode: Parallel mode
"""
...
...
@@ -383,7 +384,7 @@ class _ReduceScatterTensor3D(torch.autograd.Function):
def
reduce_scatter_tensor_3d
(
tensor
:
Tensor
,
dim
:
int
,
parallel_mode
:
ParallelMode
)
->
Tensor
:
"""
Reduce-scatter the input.
:param tensor: Input tensor
:param dim: Dimension to scatter
:param parallel_mode: Parallel mode
...
...
@@ -431,7 +432,8 @@ def reduce_by_batch_3d(tensor: Tensor,
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size), default to False
:param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size),
default to False
:type reduce_mean: bool, optional
"""
return
_ReduceByBatch3D
.
apply
(
tensor
,
input_parallel_mode
,
weight_parallel_mode
,
reduce_mean
)
...
...
colossalai/nn/layer/parallel_3d/layers.py
View file @
cbb6436f
...
...
@@ -17,7 +17,8 @@ from torch import Tensor
from
torch.nn
import
Parameter
from
..utils
import
divide
,
set_tensor_parallel_attribute_by_partition
,
to_2tuple
from
._operation
import
*
from
._operation
import
layernorm_3d
,
linear_3d
,
classifier_3d
,
split_tensor_3d
from
._operation
import
all_gather_tensor_3d
,
reduce_scatter_tensor_3d
,
broadcast_weight_3d_from_diagonal
from
._utils
import
get_depth_from_env
,
get_last_group
,
get_parallel_mode_from_env
,
swap_in_out_group
...
...
@@ -26,8 +27,9 @@ class LayerNorm3D(ParallelLayer):
r
"""
Layer Normalization for 3D parallelism
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
:param normalized_shape: input shape from an expected input of size.
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
\times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
...
...
@@ -38,6 +40,7 @@ class LayerNorm3D(ParallelLayer):
"""
def
__init__
(
self
,
normalized_shape
:
int
,
eps
:
float
=
1e-12
,
dtype
=
None
):
super
().
__init__
()
self
.
input_parallel_mode
=
get_parallel_mode_from_env
(
INPUT_GROUP_3D
)
self
.
weight_parallel_mode
=
get_parallel_mode_from_env
(
WEIGHT_GROUP_3D
)
...
...
@@ -405,7 +408,7 @@ class PatchEmbedding3D(ParallelLayer):
input_
=
split_tensor_3d
(
input_
,
0
,
self
.
input_parallel_mode
)
output
=
F
.
conv2d
(
input_
,
self
.
weight
,
self
.
bias
,
stride
=
self
.
patch_size
)
if
self
.
flatten
:
output
=
output
.
flatten
(
2
).
transpose
(
1
,
2
)
# BCHW -> BNC
output
=
output
.
flatten
(
2
).
transpose
(
1
,
2
)
# BCHW -> BNC
cls_token
=
self
.
cls_token
.
expand
(
output
.
shape
[
0
],
-
1
,
-
1
)
output
=
torch
.
cat
((
cls_token
,
output
),
dim
=
1
)
...
...
@@ -549,7 +552,7 @@ class VocabParallelEmbedding3D(torch.nn.Module):
def
_fill_padding_idx_with_zero
(
self
)
->
None
:
if
self
.
padding_idx
is
not
None
and
\
self
.
padding_idx
>=
self
.
vocab_start_index
and
self
.
padding_idx
<
self
.
vocab_end_index
:
self
.
padding_idx
>=
self
.
vocab_start_index
and
self
.
padding_idx
<
self
.
vocab_end_index
:
with
torch
.
no_grad
():
self
.
weight
[
self
.
padding_idx
-
self
.
vocab_start_index
].
fill_
(
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment