ColossalAI · Commit b8899e09 (Unverified)

[TP] allow layernorm without bias (#750)

Authored Apr 14, 2022 by アマデウス; committed by GitHub on Apr 14, 2022
Parent: 3d7dc46d

Showing 8 changed files with 134 additions and 53 deletions (+134 -53)
colossalai/nn/layer/colossalai_layer/normalization.py   +13  -4
colossalai/nn/layer/parallel_1d/layers.py               +12  -13
colossalai/nn/layer/parallel_2d/layers.py               +19  -8
colossalai/nn/layer/parallel_2p5d/layers.py             +19  -9
colossalai/nn/layer/parallel_3d/_operation.py           +17  -9
colossalai/nn/layer/parallel_3d/layers.py               +14  -6
colossalai/nn/layer/vanilla/__init__.py                 +6   -4
colossalai/nn/layer/vanilla/layers.py                   +34  -0
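At the user level, the commit threads a single `bias` flag through every LayerNorm variant, so a model can drop the additive bias term regardless of the tensor-parallel mode. A minimal usage sketch, assuming an installed and initialized ColossalAI environment (the `colossalai.nn` import path mirrors how these layers are normally exposed, but is an assumption here):

    import colossalai.nn as col_nn

    # `bias=False` is the new switch added by this commit; no bias parameter is
    # created in the wrapped implementation (vanilla, 1D, 2D, 2.5D, or 3D).
    norm = col_nn.LayerNorm(1024, eps=1e-5, bias=False)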
colossalai/nn/layer/colossalai_layer/normalization.py

@@ -6,9 +6,16 @@ from ..parallel_2d import LayerNorm2D
 from ..parallel_2p5d import LayerNorm2p5D
 from ..parallel_3d import LayerNorm3D
 from ..utils import get_tensor_parallel_mode
+from ..vanilla import VanillaLayerNorm
 from ._utils import ColossalaiModule
 
-_parallel_layernorm = {'1d': LayerNorm1D, '2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
+_parallel_layernorm = {
+    None: VanillaLayerNorm,
+    "1d": LayerNorm1D,
+    "2d": LayerNorm2D,
+    "2.5d": LayerNorm2p5D,
+    "3d": LayerNorm3D,
+}
 
 
 class LayerNorm(ColossalaiModule):

@@ -16,14 +23,16 @@ class LayerNorm(ColossalaiModule):
     Args:
         normalized_shape (int): input shape from an expected input of size.
-            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
+            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+            \times \ldots \times \text{normalized_shape}[-1]]`
             If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
-        eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05
+        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
+    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None:
         tensor_parallel = get_tensor_parallel_mode()
         if tensor_parallel is None:
             norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
 ...
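The new `None` key lets the `LayerNorm` wrapper resolve an implementation for every tensor-parallel mode through one dictionary lookup instead of special-casing the non-parallel path. A small illustrative sketch of that lookup pattern (the names mirror the diff, but these stand-in classes are not the real ColossalAI layers):

    # Stand-in classes for illustration only.
    class VanillaLayerNorm: ...
    class LayerNorm1D: ...

    _parallel_layernorm = {None: VanillaLayerNorm, '1d': LayerNorm1D}

    def resolve(tensor_parallel_mode):
        # "No tensor parallelism" (mode None) now goes through the same table.
        return _parallel_layernorm[tensor_parallel_mode]

    assert resolve(None) is VanillaLayerNorm
    assert resolve('1d') is LayerNorm1D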
colossalai/nn/layer/parallel_1d/layers.py

@@ -19,7 +19,7 @@ from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_
 from colossalai.utils.cuda import get_current_device
 from torch import Tensor
 from torch.nn.parameter import Parameter
 
-from ..vanilla import VanillaPatchEmbedding
+from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm
 from ..base_layer import ParallelLayer
 from ..colossalai_layer._utils import ColossalaiModule

@@ -85,20 +85,19 @@ class LayerNorm1D(ColossalaiModule):
     r"""
     Layer Normalization for colossalai
 
-    :param normalized_shape: input shape from an expected input
-        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
-        \times \ldots \times \text{normalized_shape}[-1]]`
-        If a single integer is used, it is treated as a singleton list, and this module will
-        normalize over the last dimension which is expected to be of that specific size.
-    :type normalized_shape: int
-    :param eps: a value added to the denominator for numerical stability, defaults to 1e-05
-    :type eps: float, optional
-    :param dtype: The dtype of parameters, defaults to None
-    :type dtype: torch.dtype, optional
+    Args:
+        normalized_shape (int): input shape from an expected input of size.
+            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+            \times \ldots \times \text{normalized_shape}[-1]]`
+            If a single integer is used, it is treated as a singleton list, and this module will
+            normalize over the last dimension which is expected to be of that specific size.
+        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
+        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None):
-        norm = LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
+    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
+        norm = VanillaLayerNorm(normalized_shape, eps=eps, bias=bias, dtype=dtype)
         super().__init__(norm)
 
     def _load_from_state_dict(self, state_dict, prefix, *args):
 ...
colossalai/nn/layer/parallel_2d/layers.py

@@ -216,10 +216,11 @@ class LayerNorm2D(ParallelLayer):
         If a single integer is used, it is treated as a singleton list, and this module will
         normalize over the last dimension which is expected to be of that specific size.
         eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
+    def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None):
         super().__init__()
 
         # layer norm config

@@ -239,13 +240,17 @@ class LayerNorm2D(ParallelLayer):
         factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
 
         self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
-        self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        if bias:
+            self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        else:
+            self.bias = None
 
         self._set_tensor_parallel_attributes()
 
     def _set_tensor_parallel_attributes(self):
         set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
-        set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
+        if self.bias is not None:
+            set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
 
     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()

@@ -294,7 +299,9 @@ class LayerNorm2D(ParallelLayer):
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
-        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
         # gather in column groups
         local_state = gather_tensor_parallel_state_dict(

@@ -345,13 +352,17 @@ class LayerNorm2D(ParallelLayer):
         output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW,
                               ParallelMode.PARALLEL_2D_COL)
-        bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank,
-                           ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
-                           self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
         scale = add_bias_2d(None, self.weight, self.partitioned_partition, self.row_rank, self.col_rank,
                             ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
                             self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
-        output = torch.addcmul(bias, scale, output)
+        if self.bias is not None:
+            bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank,
+                               ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
+                               self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
+                               self.tensor_parallel_size)
+            output = torch.addcmul(bias, scale, output)
+        else:
+            output = torch.mul(scale, output)
         return output
 ...
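The forward path keeps the fused `torch.addcmul` only when a bias exists; without one it falls back to a plain elementwise scale. A quick single-tensor check of the two branches (plain PyTorch, independent of the parallel machinery; shapes are arbitrary):

    import torch

    bias, scale, normalized = torch.randn(8), torch.randn(8), torch.randn(8)

    # Bias present: out = bias + scale * normalized, fused into one op.
    assert torch.allclose(torch.addcmul(bias, scale, normalized), bias + scale * normalized)

    # Bias disabled: only the elementwise scale is applied.
    assert torch.allclose(torch.mul(scale, normalized), scale * normalized)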
colossalai/nn/layer/parallel_2p5d/layers.py

@@ -235,10 +235,11 @@ class LayerNorm2p5D(ParallelLayer):
         If a single integer is used, it is treated as a singleton list, and this module will
         normalize over the last dimension which is expected to be of that specific size.
         eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
+    def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None):
         super().__init__()
 
         # layer norm config

@@ -259,13 +260,17 @@ class LayerNorm2p5D(ParallelLayer):
         factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
 
         self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
-        self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        if bias:
+            self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        else:
+            self.bias = None
 
         self._set_tensor_parallel_attribute()
 
     def _set_tensor_parallel_attribute(self):
         set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim)
-        set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
+        if self.bias is not None:
+            set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
 
     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()

@@ -314,7 +319,9 @@ class LayerNorm2p5D(ParallelLayer):
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
-        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
         # gather in column groups
         local_state = gather_tensor_parallel_state_dict(

@@ -364,15 +371,18 @@ class LayerNorm2p5D(ParallelLayer):
         Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
 
         output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW)
-        bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank,
-                             self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
-                             self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
-                             self.tensor_parallel_size)
         scale = add_bias_2p5d(None, self.weight, self.partitioned_partition, self.tesseract_dim, self.row_rank,
                               self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
                               self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
                               self.tensor_parallel_size)
-        output = torch.addcmul(bias, scale, output)
+        if self.bias is not None:
+            bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank,
+                                 self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
+                                 self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
+                                 self.tensor_parallel_size)
+            output = torch.addcmul(bias, scale, output)
+        else:
+            output = torch.mul(scale, output)
         return output
 ...
colossalai/nn/layer/parallel_3d/_operation.py

@@ -190,7 +190,7 @@ class _Layernorm3D(torch.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
-    def forward(ctx, input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
+    def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
                 input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
                 output_parallel_mode: ParallelMode) -> Tensor:
         mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape

@@ -201,8 +201,11 @@ class _Layernorm3D(torch.autograd.Function):
         ctx.save_for_backward(mu, sigma, weight)
 
         z = mu / sigma
-        output = weight * z + bias
+        output = weight * z
+        if bias is not None:
+            output = output + bias
 
+        ctx.use_bias = bias is not None
         ctx.normalized_shape = normalized_shape
         ctx.input_parallel_mode = input_parallel_mode
         ctx.weight_parallel_mode = weight_parallel_mode

@@ -215,12 +218,17 @@ class _Layernorm3D(torch.autograd.Function):
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         mu, sigma, weight = ctx.saved_tensors
         with torch.no_grad():
-            bias_grad, weight_grad = output_grad, output_grad * mu / sigma
-            grads = torch.stack([bias_grad, weight_grad]).contiguous()
-            grads = torch.sum(grads, dim=tuple(range(len(grads.shape))[1:-1]))
-            grads = all_reduce(grads, ctx.weight_parallel_mode)
-            grads = all_reduce(grads, ctx.input_parallel_mode)
-            bias_grad, weight_grad = grads[0], grads[1]
+            weight_grad = output_grad * mu / sigma
+            if ctx.use_bias:
+                bias_grad = output_grad
+                weight_grad = torch.stack([bias_grad, weight_grad]).contiguous()
+            else:
+                bias_grad = None
+            weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[1:-1]))
+            weight_grad = all_reduce(weight_grad, ctx.weight_parallel_mode)
+            weight_grad = all_reduce(weight_grad, ctx.input_parallel_mode)
+            if ctx.use_bias:
+                bias_grad, weight_grad = weight_grad[0], weight_grad[1]
 
             dz = output_grad * weight
             dvar = dz * mu * (-0.5) * sigma**(-3)

@@ -234,7 +242,7 @@ class _Layernorm3D(torch.autograd.Function):
         return input_grad, weight_grad, bias_grad, None, None, None, None, None
 
 
-def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
+def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
                  input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
                  output_parallel_mode: ParallelMode) -> Tensor:
     r"""3D parallel Layernorm.
 ...
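When a bias is used, the backward still stacks `bias_grad` and `weight_grad` into one tensor so the dimension reduction and the two all-reduces run once over both gradients; without a bias, only `weight_grad` takes that path. A single-process sketch of the stacking trick (shapes are hypothetical, and the real code all-reduces across the 3D parallel groups rather than just summing locally):

    import torch

    output_grad = torch.randn(4, 8, 16)   # hypothetical (batch, seq, hidden-partition)
    mu = torch.randn(4, 8, 16)
    sigma = torch.rand(4, 8, 16) + 0.1

    bias_grad = output_grad
    weight_grad = output_grad * mu / sigma

    # Stack the two parameter grads, then reduce all non-leading, non-trailing dims at once.
    grads = torch.stack([bias_grad, weight_grad]).contiguous()
    grads = torch.sum(grads, dim=tuple(range(len(grads.shape))[1:-1]))

    # The stacked reduction matches reducing each gradient separately.
    assert torch.allclose(grads[0], bias_grad.sum(dim=(0, 1)), atol=1e-5)
    assert torch.allclose(grads[1], weight_grad.sum(dim=(0, 1)), atol=1e-5)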
colossalai/nn/layer/parallel_3d/layers.py

@@ -36,10 +36,11 @@ class LayerNorm3D(ParallelLayer):
         If a single integer is used, it is treated as a singleton list, and this module will
         normalize over the last dimension which is expected to be of that specific size.
         eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-12.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
-    def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype=None):
+    def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=None):
         super().__init__()
         self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)

@@ -51,18 +52,23 @@ class LayerNorm3D(ParallelLayer):
         self.weight = Parameter(
             torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
-        self.bias = Parameter(torch.zeros(self.normalized_shape_per_partition, device=get_current_device(),
-                                          dtype=dtype))
+        if bias:
+            self.bias = Parameter(
+                torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
+        else:
+            self.bias = None
         self.variance_epsilon = eps
         self._set_tensor_parallel_attributes()
 
     def _set_tensor_parallel_attributes(self) -> None:
         set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
-        set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
+        if self.bias is not None:
+            set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
 
     def reset_parameters(self) -> None:
-        init.zeros_()(self.bias)
         init.ones_()(self.weight)
+        if self.bias is not None:
+            init.zeros_()(self.bias)
 
     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()

@@ -104,7 +110,9 @@ class LayerNorm3D(ParallelLayer):
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
-        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
         # gather in output groups
         if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
 ...
colossalai/nn/layer/vanilla/__init__.py

-from .layers import DropPath, VanillaClassifier, VanillaPatchEmbedding, \
-    WrappedDropout, WrappedDropPath
+from .layers import (DropPath, VanillaClassifier, VanillaLayerNorm,
+                     VanillaPatchEmbedding, WrappedDropout, WrappedDropPath)
 
-__all__ = ['VanillaPatchEmbedding', 'VanillaClassifier', 'DropPath',
-           'WrappedDropout', 'WrappedDropPath']
+__all__ = [
+    "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath"
+]
colossalai/nn/layer/vanilla/layers.py

@@ -254,3 +254,37 @@ class VanillaClassifier(nn.Module):
     def forward(self, input_: Tensor) -> Tensor:
         return F.linear(input_, self.weight, self.bias)
+
+
+@LAYERS.register_module
+class VanillaLayerNorm(nn.Module):
+    r"""
+    Layer Normalization for colossalai
+
+    Args:
+        normalized_shape (int): input shape from an expected input of size.
+            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+            \times \ldots \times \text{normalized_shape}[-1]]`
+            If a single integer is used, it is treated as a singleton list, and this module will
+            normalize over the last dimension which is expected to be of that specific size.
+        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
+        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+    """
+
+    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
+        super().__init__()
+
+        self.normalized_shape = (normalized_shape,)
+        self.variance_epsilon = eps
+
+        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+
+        self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs))
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(normalized_shape, **factory_kwargs))
+        else:
+            self.bias = None
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon)
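The whole change ultimately leans on `torch.nn.functional.layer_norm` accepting `bias=None`, in which case no additive term is applied. A quick standalone check of that behavior in plain PyTorch:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 5, 16)
    weight = torch.ones(16)

    # bias=None skips the additive term, matching an explicit zero bias.
    no_bias = F.layer_norm(x, (16,), weight, None, 1e-5)
    zero_bias = F.layer_norm(x, (16,), weight, torch.zeros(16), 1e-5)
    assert torch.allclose(no_bias, zero_bias)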