OpenDAS / ColossalAI / Commits

Commit b8899e09 (unverified), authored Apr 14, 2022 by アマデウス, committed by GitHub on Apr 14, 2022

[TP] allow layernorm without bias (#750)

Parent: 3d7dc46d

Showing 8 changed files with 134 additions and 53 deletions (+134, -53):
colossalai/nn/layer/colossalai_layer/normalization.py   +13  -4
colossalai/nn/layer/parallel_1d/layers.py                +12  -13
colossalai/nn/layer/parallel_2d/layers.py                +19  -8
colossalai/nn/layer/parallel_2p5d/layers.py              +19  -9
colossalai/nn/layer/parallel_3d/_operation.py            +17  -9
colossalai/nn/layer/parallel_3d/layers.py                +14  -6
colossalai/nn/layer/vanilla/__init__.py                  +6   -4
colossalai/nn/layer/vanilla/layers.py                    +34  -0
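Taken together, the diffs below add a bias flag to every LayerNorm variant (vanilla, 1D, 2D, 2.5D and 3D tensor parallelism); when it is False, no bias parameter is created and the affine transform reduces to a pure scale. A minimal usage sketch, assuming ColossalAI has already been launched so that get_tensor_parallel_mode() returns the configured mode (the hidden size of 1024 is illustrative):

import colossalai.nn as col_nn

# Default behaviour is unchanged: both scale and shift parameters are created.
norm = col_nn.LayerNorm(1024)

# New in this commit: build the layer without a bias term.
norm_no_bias = col_nn.LayerNorm(1024, bias=False)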
colossalai/nn/layer/colossalai_layer/normalization.py
@@ -6,9 +6,16 @@ from ..parallel_2d import LayerNorm2D
 from ..parallel_2p5d import LayerNorm2p5D
 from ..parallel_3d import LayerNorm3D
 from ..utils import get_tensor_parallel_mode
+from ..vanilla import VanillaLayerNorm
 from ._utils import ColossalaiModule

-_parallel_layernorm = {'1d': LayerNorm1D, '2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
+_parallel_layernorm = {
+    None: VanillaLayerNorm,
+    "1d": LayerNorm1D,
+    "2d": LayerNorm2D,
+    "2.5d": LayerNorm2p5D,
+    "3d": LayerNorm3D,
+}


 class LayerNorm(ColossalaiModule):
@@ -16,14 +23,16 @@ class LayerNorm(ColossalaiModule):
     Args:
         normalized_shape (int): input shape from an expected input of size.
-            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
+            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+            \times \ldots \times \text{normalized_shape}[-1]]`
             If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
-        eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05
+        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """

-    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
+    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None:
         tensor_parallel = get_tensor_parallel_mode()
         if tensor_parallel is None:
             norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
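The dispatch dictionary now keys on None as well as the tensor-parallel mode strings, so the wrapper can resolve a concrete implementation with a single lookup. A small self-contained sketch of that pattern (the classes here are stand-ins for the ones imported above; the real dict also maps '2d', '2.5d' and '3d'):

# Stand-ins for the imported layer classes, just to show the lookup.
class VanillaLayerNorm: ...
class LayerNorm1D: ...

_parallel_layernorm = {None: VanillaLayerNorm, '1d': LayerNorm1D}

def resolve(mode):
    # mode is whatever get_tensor_parallel_mode() returns: None or a mode string.
    return _parallel_layernorm[mode]

assert resolve(None) is VanillaLayerNorm
assert resolve('1d') is LayerNorm1D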
colossalai/nn/layer/parallel_1d/layers.py
@@ -19,7 +19,7 @@ from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_
 from colossalai.utils.cuda import get_current_device
 from torch import Tensor
 from torch.nn.parameter import Parameter

-from ..vanilla import VanillaPatchEmbedding
+from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm
 from ..base_layer import ParallelLayer
 from ..colossalai_layer._utils import ColossalaiModule
@@ -85,20 +85,19 @@ class LayerNorm1D(ColossalaiModule):
     r"""
     Layer Normalization for colossalai

-    :param normalized_shape: input shape from an expected input
-        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
-        \times \ldots \times \text{normalized_shape}[-1]]`
-        If a single integer is used, it is treated as a singleton list, and this module will
-        normalize over the last dimension which is expected to be of that specific size.
-    :type normalized_shape: int
-    :param eps: a value added to the denominator for numerical stability, defaults to 1e-05
-    :type eps: float, optional
-    :param dtype: The dtype of parameters, defaults to None
-    :type dtype: torch.dtype, optional
+    Args:
+        normalized_shape (int): input shape from an expected input of size.
+            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+            \times \ldots \times \text{normalized_shape}[-1]]`
+            If a single integer is used, it is treated as a singleton list, and this module will
+            normalize over the last dimension which is expected to be of that specific size.
+        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
+        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """

-    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None):
-        norm = LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
+    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
+        norm = VanillaLayerNorm(normalized_shape, eps=eps, bias=bias, dtype=dtype)
         super().__init__(norm)

     def _load_from_state_dict(self, state_dict, prefix, *args):
colossalai/nn/layer/parallel_2d/layers.py
@@ -216,10 +216,11 @@ class LayerNorm2D(ParallelLayer):
         If a single integer is used, it is treated as a singleton list, and this module will
         normalize over the last dimension which is expected to be of that specific size.
         eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """

-    def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
+    def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None):
         super().__init__()

         # layer norm config
@@ -239,13 +240,17 @@ class LayerNorm2D(ParallelLayer):
         factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
         self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
-        self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        if bias:
+            self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        else:
+            self.bias = None

         self._set_tensor_parallel_attributes()

     def _set_tensor_parallel_attributes(self):
         set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
-        set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
+        if self.bias is not None:
+            set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)

     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
@@ -294,7 +299,9 @@ class LayerNorm2D(ParallelLayer):
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
-        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
         # gather in column groups
         local_state = gather_tensor_parallel_state_dict(
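The state-dict change keeps checkpoints consistent with the new option: the 'bias' key is only written when the parameter exists. A minimal sketch of that pattern, detached from the parallel gathering logic (values are illustrative):

from collections import OrderedDict

import torch

weight = torch.ones(8)
bias = None  # what self.bias becomes when the layer was built with bias=False

# Only include the bias entry when the parameter is present.
local_state = OrderedDict({'weight': weight})
if bias is not None:
    local_state['bias'] = bias

assert list(local_state.keys()) == ['weight']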
@@ -345,13 +352,17 @@ class LayerNorm2D(ParallelLayer):
         output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW,
                               ParallelMode.PARALLEL_2D_COL)
-        bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank,
-                           ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
-                           self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
         scale = add_bias_2d(None, self.weight, self.partitioned_partition, self.row_rank, self.col_rank,
                             ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
                             self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
-        output = torch.addcmul(bias, scale, output)
+        if self.bias is not None:
+            bias = add_bias_2d(None, self.bias, self.partitioned_partition, self.row_rank, self.col_rank,
+                               ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
+                               self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
+                               self.tensor_parallel_size)
+            output = torch.addcmul(bias, scale, output)
+        else:
+            output = torch.mul(scale, output)
         return output
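The forward split relies on the semantics of torch.addcmul: addcmul(bias, scale, output) computes bias + scale * output elementwise, so dropping the bias leaves only the multiplicative half, which is exactly what the new torch.mul branch produces. A quick sanity check (shapes are illustrative):

import torch

bias, scale, out = torch.randn(4), torch.randn(4), torch.randn(4)

with_bias = torch.addcmul(bias, scale, out)   # bias + scale * out
without_bias = torch.mul(scale, out)          # scale * out only

assert torch.allclose(with_bias, bias + scale * out)
assert torch.allclose(without_bias, with_bias - bias)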
colossalai/nn/layer/parallel_2p5d/layers.py
@@ -235,10 +235,11 @@ class LayerNorm2p5D(ParallelLayer):
         If a single integer is used, it is treated as a singleton list, and this module will
         normalize over the last dimension which is expected to be of that specific size.
         eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """

-    def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
+    def __init__(self, normalized_shape: int, eps: float = 1e-05, bias=True, dtype=None):
         super().__init__()

         # layer norm config
@@ -259,13 +260,17 @@ class LayerNorm2p5D(ParallelLayer):
         factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
         self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
-        self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        if bias:
+            self.bias = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
+        else:
+            self.bias = None

         self._set_tensor_parallel_attribute()

     def _set_tensor_parallel_attribute(self):
         set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim)
-        set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
+        if self.bias is not None:
+            set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)

     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
@@ -314,7 +319,9 @@ class LayerNorm2p5D(ParallelLayer):
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
-        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
         # gather in column groups
         local_state = gather_tensor_parallel_state_dict(
@@ -364,15 +371,18 @@ class LayerNorm2p5D(ParallelLayer):
         Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)

         output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW)
-        bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank,
-                             self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
-                             self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
-                             self.tensor_parallel_size)
         scale = add_bias_2p5d(None, self.weight, self.partitioned_partition, self.tesseract_dim, self.row_rank,
                               self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
                               self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
                               self.tensor_parallel_size)
-        output = torch.addcmul(bias, scale, output)
+        if self.bias is not None:
+            bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank,
+                                 self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
+                                 self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
+                                 self.tensor_parallel_size)
+            output = torch.addcmul(bias, scale, output)
+        else:
+            output = torch.mul(scale, output)
         return output
colossalai/nn/layer/parallel_3d/_operation.py
@@ -190,7 +190,7 @@ class _Layernorm3D(torch.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
-    def forward(ctx, input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
+    def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
                 input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
                 output_parallel_mode: ParallelMode) -> Tensor:
         mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
@@ -201,8 +201,11 @@ class _Layernorm3D(torch.autograd.Function):
         ctx.save_for_backward(mu, sigma, weight)

         z = mu / sigma
-        output = weight * z + bias
+        output = weight * z
+        if bias is not None:
+            output = output + bias

+        ctx.use_bias = bias is not None
         ctx.normalized_shape = normalized_shape
         ctx.input_parallel_mode = input_parallel_mode
         ctx.weight_parallel_mode = weight_parallel_mode
@@ -215,12 +218,17 @@ class _Layernorm3D(torch.autograd.Function):
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         mu, sigma, weight = ctx.saved_tensors
         with torch.no_grad():
-            bias_grad, weight_grad = output_grad, output_grad * mu / sigma
-            grads = torch.stack([bias_grad, weight_grad]).contiguous()
-            grads = torch.sum(grads, dim=tuple(range(len(grads.shape))[1:-1]))
-            grads = all_reduce(grads, ctx.weight_parallel_mode)
-            grads = all_reduce(grads, ctx.input_parallel_mode)
-            bias_grad, weight_grad = grads[0], grads[1]
+            weight_grad = output_grad * mu / sigma
+            if ctx.use_bias:
+                bias_grad = output_grad
+                weight_grad = torch.stack([bias_grad, weight_grad]).contiguous()
+            else:
+                bias_grad = None
+            weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[1:-1]))
+            weight_grad = all_reduce(weight_grad, ctx.weight_parallel_mode)
+            weight_grad = all_reduce(weight_grad, ctx.input_parallel_mode)
+            if ctx.use_bias:
+                bias_grad, weight_grad = weight_grad[0], weight_grad[1]

             dz = output_grad * weight
             dvar = dz * mu * (-0.5) * sigma**(-3)
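For reference, the quantities being stacked, summed and all-reduced in this backward pass are the standard LayerNorm parameter gradients. Writing z = (x - mu) / sigma for the normalized activations (mu / sigma in the code above) and g for the incoming gradient:

\frac{\partial L}{\partial \gamma_j} = \sum_i g_{ij} \, z_{ij},
\qquad
\frac{\partial L}{\partial \beta_j} = \sum_i g_{ij}

where i runs over all non-feature positions held locally; the two all_reduce calls then accumulate the partial sums contributed by the other ranks in the weight- and input-parallel groups. When the layer has no bias, only the gradient with respect to gamma is reduced and bias_grad stays None.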
@@ -234,7 +242,7 @@ class _Layernorm3D(torch.autograd.Function):
         return input_grad, weight_grad, bias_grad, None, None, None, None, None


-def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
+def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
                  input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
                  output_parallel_mode: ParallelMode) -> Tensor:
     r"""3D parallel Layernorm.
colossalai/nn/layer/parallel_3d/layers.py
@@ -36,10 +36,11 @@ class LayerNorm3D(ParallelLayer):
         If a single integer is used, it is treated as a singleton list, and this module will
         normalize over the last dimension which is expected to be of that specific size.
         eps (float, optional): a value added to the denominator for numerical stability, defaults to 1e-12.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """

-    def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype=None):
+    def __init__(self, normalized_shape: int, eps: float = 1e-12, bias=True, dtype=None):
         super().__init__()
         self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -51,18 +52,23 @@ class LayerNorm3D(ParallelLayer):
         self.weight = Parameter(
             torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
-        self.bias = Parameter(
-            torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
+        if bias:
+            self.bias = Parameter(
+                torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
+        else:
+            self.bias = None
         self.variance_epsilon = eps
         self._set_tensor_parallel_attributes()

     def _set_tensor_parallel_attributes(self) -> None:
         set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
-        set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
+        if self.bias is not None:
+            set_tensor_parallel_attribute_by_partition(self.bias, self.depth)

     def reset_parameters(self) -> None:
-        init.zeros_()(self.bias)
         init.ones_()(self.weight)
+        if self.bias is not None:
+            init.zeros_()(self.bias)

     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         local_state = OrderedDict()
@@ -104,7 +110,9 @@ class LayerNorm3D(ParallelLayer):
     def _save_to_state_dict(self, destination, prefix, keep_vars):
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
-        local_state = OrderedDict({weight_key: self.weight, bias_key: self.bias})
+        local_state = OrderedDict({weight_key: self.weight})
+        if self.bias is not None:
+            local_state[bias_key] = self.bias
         # gather in output groups
         if gpc.get_local_rank(self.input_parallel_mode) == 0 and \
colossalai/nn/layer/vanilla/__init__.py
-from .layers import DropPath, VanillaClassifier, VanillaPatchEmbedding, \
-    WrappedDropout, WrappedDropPath
+from .layers import (DropPath, VanillaClassifier, VanillaLayerNorm, VanillaPatchEmbedding, WrappedDropout,
+                     WrappedDropPath)

-__all__ = ['VanillaPatchEmbedding', 'VanillaClassifier', 'DropPath', 'WrappedDropout', 'WrappedDropPath']
+__all__ = ["VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout",
+           "WrappedDropPath"]
colossalai/nn/layer/vanilla/layers.py
@@ -254,3 +254,37 @@ class VanillaClassifier(nn.Module):

     def forward(self, input_: Tensor) -> Tensor:
         return F.linear(input_, self.weight, self.bias)
+
+
+@LAYERS.register_module
+class VanillaLayerNorm(nn.Module):
+    r"""
+    Layer Normalization for colossalai
+
+    Args:
+        normalized_shape (int): input shape from an expected input of size.
+            :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+            \times \ldots \times \text{normalized_shape}[-1]]`
+            If a single integer is used, it is treated as a singleton list, and this module will
+            normalize over the last dimension which is expected to be of that specific size.
+        eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
+        bias (bool, optional): Whether to add a bias, defaults to ``True``.
+        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+    """
+
+    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
+        super().__init__()
+
+        self.normalized_shape = (normalized_shape,)
+        self.variance_epsilon = eps
+
+        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+
+        self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs))
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(normalized_shape, **factory_kwargs))
+        else:
+            self.bias = None
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon)
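F.layer_norm accepts None for both the weight and bias arguments, which is why VanillaLayerNorm.forward needs no extra branching for the bias-free case. A short equivalence check (shapes and values are illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 16)
weight = torch.randn(16)

# Scale-only layer norm, as computed by VanillaLayerNorm(..., bias=False).
y = F.layer_norm(x, (16,), weight, None, 1e-5)

# Reference: normalize first, then apply the scale.
ref = F.layer_norm(x, (16,), None, None, 1e-5) * weight

assert torch.allclose(y, ref, atol=1e-6)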