"examples/git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "2a4754785144a08f1e1feeb11fad87bbd6e41610"
Unverified Commit a786ca0c authored by eqy's avatar eqy Committed by GitHub
Browse files

fix and generate docs for FusedRMSNorm (#1285)

parent 684c4733
...@@ -303,19 +303,19 @@ class FusedRMSNorm(torch.nn.Module): ...@@ -303,19 +303,19 @@ class FusedRMSNorm(torch.nn.Module):
Currently only runs on cuda() tensors. Currently only runs on cuda() tensors.
.. math:: .. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta y = \frac{x}{\mathrm{RMS}[x]} * \gamma
The mean and standard-deviation are calculated separately over the last The root-mean-square is calculated separately over the last
certain number dimensions which have to be of the shape specified by certain number dimensions which have to be of the shape specified by
:attr:`normalized_shape`. :attr:`normalized_shape`.
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of :math:`\gamma` is a learnable affine transform parameter of
:attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``. :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
.. note:: .. note::
Unlike Batch Normalization and Instance Normalization, which applies Unlike Batch Normalization and Instance Normalization, which applies
scalar scale and bias for each entire channel/plane with the scalar scale and bias for each entire channel/plane with the
:attr:`affine` option, Layer Normalization applies per-element scale and :attr:`affine` option, RMS Normalization applies per-element scale
bias with :attr:`elementwise_affine`. with :attr:`elementwise_affine`.
This layer uses statistics computed from input data in both training and This layer uses statistics computed from input data in both training and
evaluation modes. evaluation modes.
...@@ -353,7 +353,7 @@ class FusedRMSNorm(torch.nn.Module): ...@@ -353,7 +353,7 @@ class FusedRMSNorm(torch.nn.Module):
>>> # Activating the module >>> # Activating the module
>>> output = m(input) >>> output = m(input)
.. _`Layer Normalization`: https://arxiv.org/abs/1607.06450 .. _`Root Mean Square Layer Normalization`: https://arxiv.org/pdf/1910.07467.pdf
""" """
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
......
...@@ -12,3 +12,6 @@ apex.normalization.fused_layer_norm ...@@ -12,3 +12,6 @@ apex.normalization.fused_layer_norm
.. autoclass:: FusedLayerNorm .. autoclass:: FusedLayerNorm
:members: :members:
.. autoclass:: FusedRMSNorm
:members:
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment