Unverified Commit bee361c6 authored by Stas Bekman, committed by GitHub

[t5/t0/mt5 models] faster/leaner custom layer norm (#14656)

* [t5] faster/leaner custom layer norm

* wip

* apex.normalization.FusedRMSNorm

* cleanup

* cleanup

* add doc

* add catch all

* Trigger CI

* expand
parent e3d1a8da
@@ -263,6 +263,11 @@ print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
<a id='scripts'></a>
## Performance
If you'd like faster training and inference performance, install [apex](https://github.com/NVIDIA/apex#quick-start) and the model will then automatically use `apex.normalization.FusedRMSNorm` instead of `T5LayerNorm`. The former uses an optimized fused kernel that is several times faster than the latter.
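As a quick way to confirm which implementation was picked up after installing apex, here is a minimal sketch (not part of this change); it assumes the `t5-small` checkpoint and that the final layer norm is exposed as `model.encoder.final_layer_norm`, as in the current `T5Stack`:

```python
# Minimal sketch: check whether apex's FusedRMSNorm replaced T5LayerNorm.
# Assumes a transformers version containing this change; apex is optional.
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")
# With apex installed this should print apex's FusedRMSNorm class,
# otherwise the stock T5LayerNorm.
print(type(model.encoder.final_layer_norm))
```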
## Example scripts
T5 is supported by several example scripts, both for pre-training and fine-tuning.
......
@@ -237,14 +237,19 @@ DEPARALLELIZE_DOCSTRING = r"""
class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
    def forward(self, hidden_states):
        # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
        # Square Layer Normalization https://arxiv.org/abs/1910.07467, thus variance is calculated
        # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
        # half-precision inputs is done in fp32
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
@@ -255,6 +260,20 @@ class T5LayerNorm(nn.Module):
        return self.weight * hidden_states
try:
    from apex.normalization import FusedRMSNorm

    T5LayerNorm = FusedRMSNorm  # noqa

    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm")
except ImportError:
    # using the normal T5LayerNorm
    pass
except Exception:
    logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
    pass
class T5DenseReluDense(nn.Module):
    def __init__(self, config):
        super().__init__()
......
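For reference, the computation `T5LayerNorm` performs above can be sketched as a standalone function. This is an illustrative re-statement of the RMS-norm math from the diff (variance over squared values with no mean subtraction, accumulation in fp32, no bias), not additional code from the commit; the function and variable names are hypothetical:

```python
import torch

def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Root Mean Square Layer Normalization: only a learned scale, no bias,
    # and no mean subtraction; the accumulation is done in fp32.
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    # Cast back to the weight dtype when the model runs in half precision.
    if weight.dtype in (torch.float16, torch.bfloat16):
        hidden_states = hidden_states.to(weight.dtype)
    return weight * hidden_states

# Usage: normalize a half-precision activation with a unit scale.
x = torch.randn(2, 4, 8).half()
w = torch.ones(8, dtype=torch.float16)
print(rms_norm(x, w).shape)  # torch.Size([2, 4, 8])
```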