Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
fceec07d
Commit
fceec07d
authored
Feb 07, 2022
by
eqy
Committed by
hubertlu-tw
Apr 14, 2022
Browse files
fix and generate docs for FusedRMSNorm (#1285)
parent
c14cfb10
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
6 deletions
+9
-6
apex/normalization/fused_layer_norm.py
apex/normalization/fused_layer_norm.py
+6
-6
docs/source/layernorm.rst
docs/source/layernorm.rst
+3
-0
No files found.
apex/normalization/fused_layer_norm.py
View file @
fceec07d
...
...
@@ -303,19 +303,19 @@ class FusedRMSNorm(torch.nn.Module):
Currently only runs on cuda() tensors.
.. math::
y = \frac{x
- \mathrm{E}[x]}{ \sqrt
{\mathrm{
Var
}[x]
+ \epsilon}
} * \gamma
+ \beta
y = \frac{x
}
{\mathrm{
RMS
}[x]} * \gamma
The
mean and standard-deviation are
calculated separately over the last
The
root-mean-square is
calculated separately over the last
certain number dimensions which have to be of the shape specified by
:attr:`normalized_shape`.
:math:`\gamma`
and :math:`\beta` are
learnable affine transform parameter
s
of
:math:`\gamma`
is a
learnable affine transform parameter of
:attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
.. note::
Unlike Batch Normalization and Instance Normalization, which applies
scalar scale and bias for each entire channel/plane with the
:attr:`affine` option,
Layer
Normalization applies per-element scale
and
bias
with :attr:`elementwise_affine`.
:attr:`affine` option,
RMS
Normalization applies per-element scale
with :attr:`elementwise_affine`.
This layer uses statistics computed from input data in both training and
evaluation modes.
...
...
@@ -353,7 +353,7 @@ class FusedRMSNorm(torch.nn.Module):
>>> # Activating the module
>>> output = m(input)
.. _`Layer Normalization`: https://arxiv.org/
abs/1607.06450
.. _`
Root Mean Square
Layer Normalization`: https://arxiv.org/
pdf/1910.07467.pdf
"""
def
__init__
(
self
,
normalized_shape
,
eps
=
1e-5
,
elementwise_affine
=
True
):
...
...
docs/source/layernorm.rst
View file @
fceec07d
...
...
@@ -12,3 +12,6 @@ apex.normalization.fused_layer_norm
.. autoclass:: FusedLayerNorm
:members:
.. autoclass:: FusedRMSNorm
:members:
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment