Commit 939ab6ae authored by Myle Ott, committed by Facebook GitHub Bot

gelu_fast -> gelu_accurate (#571)

Summary:
This was named gelu_fast after the original implementation: https://github.com/hendrycks/GELUs/blob/master/mnist_ae.py#L62-L63

In practice, however, it is actually slower and uses more memory than the exact gelu. Rename it to gelu_accurate.
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/571

Differential Revision: D15317874

Pulled By: myleott

fbshipit-source-id: c96fbc89bf91b27ced1ab8d5b25a8f23f922ec24
parent 72291287
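
For reference, a minimal, self-contained sketch (not part of this commit) of the two functions the summary compares: an erf-based exact GELU and the tanh approximation renamed here from gelu_fast to gelu_accurate. The gelu_accurate body matches the diff below; gelu_exact is an assumed reference, since the diff elides fairseq's own gelu, and timings will vary by hardware.

```python
import math
import time

import torch


def gelu_exact(x):
    # Exact GELU, x * Phi(x), written with the error function (assumed reference).
    return 0.5 * x * (1 + torch.erf(x / math.sqrt(2)))


def gelu_accurate(x):
    # Tanh approximation (Hendrycks & Gimpel); formerly exposed as gelu_fast.
    a = math.sqrt(2 / math.pi)
    return 0.5 * x * (1 + torch.tanh(a * (x + 0.044715 * torch.pow(x, 3))))


x = torch.randn(4096, 4096)
for fn in (gelu_exact, gelu_accurate):
    start = time.perf_counter()
    for _ in range(10):
        fn(x)
    print(f"{fn.__name__}: {time.perf_counter() - start:.3f}s")
```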
@@ -53,7 +53,7 @@ class TransformerModel(FairseqModel):
     def add_args(parser):
         """Add model-specific arguments to the parser."""
         # fmt: off
-        parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_fast'],
+        parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_accurate'],
                             help='Which activation function to use')
         parser.add_argument('--dropout', type=float, metavar='D',
                             help='dropout probability')
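
A hypothetical, standalone illustration of the renamed CLI choice (fairseq wires --activation-fn through its own option machinery; the default shown here exists only for this sketch):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_accurate'],
                    default='relu', help='Which activation function to use')

# The old spelling is no longer an accepted choice at the parser level.
args = parser.parse_args(['--activation-fn', 'gelu_accurate'])
print(args.activation_fn)  # gelu_accurate
```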
@@ -198,7 +198,7 @@ def transformer_lm_gpt(args):
     args.dropout = getattr(args, 'dropout', 0.1)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
     args.decoder_final_norm = getattr(args, 'decoder_final_norm', True)
-    args.activation_fn = getattr(args, 'activation_fn', 'gelu_fast')
+    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
     base_lm_architecture(args)
@@ -211,7 +211,7 @@ def transformer_lm_gpt2_small(args):
     args.dropout = getattr(args, 'dropout', 0.1)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
     args.decoder_final_norm = getattr(args, 'decoder_final_norm', True)
-    args.activation_fn = getattr(args, 'activation_fn', 'gelu_fast')
+    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
     base_lm_architecture(args)
@@ -224,7 +224,7 @@ def transformer_lm_gpt2_medium(args):
     args.dropout = getattr(args, 'dropout', 0.1)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
    args.decoder_final_norm = getattr(args, 'decoder_final_norm', True)
-    args.activation_fn = getattr(args, 'activation_fn', 'gelu_fast')
+    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
     base_lm_architecture(args)
@@ -237,5 +237,5 @@ def transformer_lm_gpt2_big(args):
     args.dropout = getattr(args, 'dropout', 0.1)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
     args.decoder_final_norm = getattr(args, 'decoder_final_norm', True)
-    args.activation_fn = getattr(args, 'activation_fn', 'gelu_fast')
+    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
     base_lm_architecture(args)
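
The four transformer_lm hunks above all flip the architecture default from gelu_fast to gelu. A small sketch of the getattr-default pattern they rely on (the function name here is invented for illustration): values already set on args win, and the architecture default is used only as a fallback.

```python
from argparse import Namespace


def apply_gpt_defaults(args):
    # Same pattern as the diff: fill in 'gelu' only if the user set nothing.
    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
    return args


print(apply_gpt_defaults(Namespace()).activation_fn)                      # gelu
print(apply_gpt_defaults(Namespace(activation_fn='relu')).activation_fn)  # relu
```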
@@ -13,7 +13,7 @@ from .character_token_embedder import CharacterTokenEmbedder
 from .conv_tbc import ConvTBC
 from .downsampled_multihead_attention import DownsampledMultiHeadAttention
 from .dynamic_convolution import DynamicConv1dTBC
-from .gelu import gelu, gelu_fast
+from .gelu import gelu, gelu_accurate
 from .grad_multiply import GradMultiply
 from .highway import Highway
 from .layer_norm import LayerNorm
@@ -40,7 +40,7 @@ __all__ = [
     'DownsampledMultiHeadAttention',
     'DynamicConv1dTBC',
     'gelu',
-    'gelu_fast',
+    'gelu_accurate',
     'GradMultiply',
     'Highway',
     'LayerNorm',
@@ -14,10 +14,10 @@ import math
 import torch


-def gelu_fast(x):
-    if not hasattr(gelu_fast, "_a"):
-        gelu_fast._a = math.sqrt(2 / math.pi)
-    return 0.5 * x * (1 + torch.tanh(gelu_fast._a * (x + 0.044715 * torch.pow(x, 3))))
+def gelu_accurate(x):
+    if not hasattr(gelu_accurate, "_a"):
+        gelu_accurate._a = math.sqrt(2 / math.pi)
+    return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))


 def gelu(x: torch.Tensor) -> torch.Tensor:
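
A quick sanity check (not from this commit) that the renamed tanh approximation stays close to an erf-based GELU. The gelu_accurate body mirrors the hunk above, including the cached sqrt(2/pi) function attribute; the erf reference is assumed, since the hunk elides fairseq's own gelu body.

```python
import math

import torch


def gelu_accurate(x):
    # Same definition as the hunk above, caching sqrt(2/pi) on the function object.
    if not hasattr(gelu_accurate, "_a"):
        gelu_accurate._a = math.sqrt(2 / math.pi)
    return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))


def gelu_exact(x):
    # Assumed erf-based reference implementation of exact GELU.
    return 0.5 * x * (1 + torch.erf(x / math.sqrt(2)))


x = torch.linspace(-6, 6, steps=1001)
print(torch.max(torch.abs(gelu_accurate(x) - gelu_exact(x))))  # small maximum deviation
```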
@@ -16,7 +16,7 @@ import warnings
 import torch
 import torch.nn.functional as F

-from fairseq.modules import gelu, gelu_fast
+from fairseq.modules import gelu, gelu_accurate


 def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
@@ -298,6 +298,9 @@ def get_activation_fn(activation: str) -> Callable:
     elif activation == 'gelu':
         return gelu
     elif activation == 'gelu_fast':
-        return gelu_fast
+        deprecation_warning('--activation-fn=gelu_fast has been renamed to gelu_accurate')
+        return gelu_accurate
+    elif activation == 'gelu_accurate':
+        return gelu_accurate
     else:
         raise RuntimeError(f"--activation-fn {activation} not supported")
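
Assuming this helper lives in fairseq/utils.py (the file above also defines load_ensemble_for_inference), usage after the rename looks roughly like the sketch below: the old name still resolves, but now emits a deprecation warning.

```python
from fairseq.utils import get_activation_fn  # assumed import path

act_old = get_activation_fn('gelu_fast')      # emits the deprecation warning
act_new = get_activation_fn('gelu_accurate')
assert act_old is act_new                     # both map to gelu_accurate
```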