Unverified Commit a786ca0c authored by eqy's avatar eqy Committed by GitHub
fix and generate docs for FusedRMSNorm (#1285)

parent 684c4733
@@ -303,19 +303,19 @@ class FusedRMSNorm(torch.nn.Module):
     Currently only runs on cuda() tensors.

     .. math::
-        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+        y = \frac{x}{\mathrm{RMS}[x]} * \gamma

-    The mean and standard-deviation are calculated separately over the last
+    The root-mean-square is calculated separately over the last
     certain number dimensions which have to be of the shape specified by
     :attr:`normalized_shape`.
-    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
+    :math:`\gamma` is a learnable affine transform parameter of
     :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.

     .. note::
         Unlike Batch Normalization and Instance Normalization, which applies
         scalar scale and bias for each entire channel/plane with the
-        :attr:`affine` option, Layer Normalization applies per-element scale and
-        bias with :attr:`elementwise_affine`.
+        :attr:`affine` option, RMS Normalization applies per-element scale
+        with :attr:`elementwise_affine`.

     This layer uses statistics computed from input data in both training and
     evaluation modes.
@@ -353,7 +353,7 @@ class FusedRMSNorm(torch.nn.Module):
         >>> # Activating the module
         >>> output = m(input)

-    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
+    .. _`Root Mean Square Layer Normalization`: https://arxiv.org/pdf/1910.07467.pdf
     """
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
......
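The new docstring formula, :math:`y = \frac{x}{\mathrm{RMS}[x]} * \gamma`, can be checked against a plain NumPy sketch. This is an illustrative reference only, not the fused CUDA kernel; the function name `rms_norm` and its signature are assumptions for the example, and `eps` plays the same role as the module's `eps` argument.

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-5):
    """Reference RMSNorm: y = x / RMS(x) * gamma, with the
    root-mean-square taken over the last dimension (i.e. a
    1-D normalized_shape)."""
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return x / rms * gamma

x = np.array([[3.0, 4.0]])
y = rms_norm(x, np.ones(2), eps=0.0)
# RMS([3, 4]) = sqrt((9 + 16) / 2) = sqrt(12.5) ≈ 3.5355
```

Note that, unlike LayerNorm, no mean is subtracted and no bias :math:`\beta` is added, which is exactly the diff above: only the per-element scale :math:`\gamma` remains.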
@@ -12,3 +12,6 @@ apex.normalization.fused_layer_norm
 .. autoclass:: FusedLayerNorm
     :members:
+
+.. autoclass:: FusedRMSNorm
+    :members: