Commit 939ab6ae authored by Myle Ott, committed by Facebook Github Bot

gelu_fast -> gelu_accurate (#571)

Summary:
This was named gelu_fast after the original implementation: https://github.com/hendrycks/GELUs/blob/master/mnist_ae.py#L62-L63

But in practice it's actually slower and uses more memory than the exact erf-based gelu, so the name is misleading. Rename it to gelu_accurate.
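For context, a rough comparison sketch (my own, not part of the commit; the function names gelu_exact and gelu_tanh are mine): the tanh approximation performs more elementwise ops and allocates more intermediate tensors per call than the erf-based formula, which is consistent with the observation that it is slower and heavier on memory.

# Rough timing sketch comparing the exact erf-based gelu with the tanh
# approximation on CPU. Sizes and iteration counts are arbitrary.
import math
import time

import torch


def gelu_exact(x):
    # exact GELU, x * Phi(x), via the error function
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    # tanh approximation from Hendrycks & Gimpel (the function formerly named gelu_fast)
    a = math.sqrt(2 / math.pi)
    return 0.5 * x * (1 + torch.tanh(a * (x + 0.044715 * torch.pow(x, 3))))


x = torch.randn(4096, 4096)

for name, fn in [('erf-based gelu', gelu_exact), ('tanh approximation', gelu_tanh)]:
    start = time.perf_counter()
    for _ in range(10):
        fn(x)
    print(f'{name}: {time.perf_counter() - start:.3f}s')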
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/571

Differential Revision: D15317874

Pulled By: myleott

fbshipit-source-id: c96fbc89bf91b27ced1ab8d5b25a8f23f922ec24
parent 72291287
@@ -53,7 +53,7 @@ class TransformerModel(FairseqModel):
     def add_args(parser):
         """Add model-specific arguments to the parser."""
         # fmt: off
-        parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_fast'],
+        parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_accurate'],
                             help='Which activation function to use')
         parser.add_argument('--dropout', type=float, metavar='D',
                             help='dropout probability')
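Since gelu_fast is dropped from the CLI choices, the flag itself now only accepts the new name; old checkpoints are still handled further down in get_activation_fn (see the last hunk below). A standalone sketch of the choices behavior (a hypothetical minimal parser mirroring the flag above, not fairseq's actual argument wiring):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_accurate'],
                    help='Which activation function to use')

print(parser.parse_args(['--activation-fn', 'gelu_accurate']))
# Namespace(activation_fn='gelu_accurate')

# The removed choice now fails argparse's check:
# parser.parse_args(['--activation-fn', 'gelu_fast'])  # SystemExit: invalid choice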
@@ -198,7 +198,7 @@ def transformer_lm_gpt(args):
     args.dropout = getattr(args, 'dropout', 0.1)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
     args.decoder_final_norm = getattr(args, 'decoder_final_norm', True)
-    args.activation_fn = getattr(args, 'activation_fn', 'gelu_fast')
+    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
     base_lm_architecture(args)

@@ -211,7 +211,7 @@ def transformer_lm_gpt2_small(args):
     args.dropout = getattr(args, 'dropout', 0.1)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
     args.decoder_final_norm = getattr(args, 'decoder_final_norm', True)
-    args.activation_fn = getattr(args, 'activation_fn', 'gelu_fast')
+    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
     base_lm_architecture(args)

@@ -224,7 +224,7 @@ def transformer_lm_gpt2_medium(args):
     args.dropout = getattr(args, 'dropout', 0.1)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
     args.decoder_final_norm = getattr(args, 'decoder_final_norm', True)
-    args.activation_fn = getattr(args, 'activation_fn', 'gelu_fast')
+    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
     base_lm_architecture(args)

@@ -237,5 +237,5 @@ def transformer_lm_gpt2_big(args):
     args.dropout = getattr(args, 'dropout', 0.1)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
     args.decoder_final_norm = getattr(args, 'decoder_final_norm', True)
-    args.activation_fn = getattr(args, 'activation_fn', 'gelu_fast')
+    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
     base_lm_architecture(args)
@@ -13,7 +13,7 @@ from .character_token_embedder import CharacterTokenEmbedder
 from .conv_tbc import ConvTBC
 from .downsampled_multihead_attention import DownsampledMultiHeadAttention
 from .dynamic_convolution import DynamicConv1dTBC
-from .gelu import gelu, gelu_fast
+from .gelu import gelu, gelu_accurate
 from .grad_multiply import GradMultiply
 from .highway import Highway
 from .layer_norm import LayerNorm

@@ -40,7 +40,7 @@ __all__ = [
     'DownsampledMultiHeadAttention',
     'DynamicConv1dTBC',
     'gelu',
-    'gelu_fast',
+    'gelu_accurate',
     'GradMultiply',
     'Highway',
     'LayerNorm',
@@ -14,10 +14,10 @@ import math
 import torch


-def gelu_fast(x):
-    if not hasattr(gelu_fast, "_a"):
-        gelu_fast._a = math.sqrt(2 / math.pi)
-    return 0.5 * x * (1 + torch.tanh(gelu_fast._a * (x + 0.044715 * torch.pow(x, 3))))
+def gelu_accurate(x):
+    if not hasattr(gelu_accurate, "_a"):
+        gelu_accurate._a = math.sqrt(2 / math.pi)
+    return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))


 def gelu(x: torch.Tensor) -> torch.Tensor:
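Usage is unchanged apart from the name. A quick sanity check (my own sketch, assuming fairseq at this revision is importable) that the renamed tanh approximation still tracks the exact gelu closely:

import torch

from fairseq.modules import gelu, gelu_accurate

x = torch.randn(8, 16)
# the tanh approximation agrees with the exact erf-based gelu to within ~1e-3
print(torch.allclose(gelu_accurate(x), gelu(x), atol=1e-3))  # True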
@@ -16,7 +16,7 @@ import warnings
 import torch
 import torch.nn.functional as F

-from fairseq.modules import gelu, gelu_fast
+from fairseq.modules import gelu, gelu_accurate


 def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):

@@ -298,6 +298,9 @@ def get_activation_fn(activation: str) -> Callable:
     elif activation == 'gelu':
         return gelu
     elif activation == 'gelu_fast':
-        return gelu_fast
+        deprecation_warning('--activation-fn=gelu_fast has been renamed to gelu_accurate')
+        return gelu_accurate
+    elif activation == 'gelu_accurate':
+        return gelu_accurate
     else:
         raise RuntimeError(f"--activation-fn {activation} not supported")
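The old name still resolves through get_activation_fn for existing configs and checkpoints; a short sketch of that backward-compatibility path (assuming the helper lives in fairseq.utils, as this diff suggests):

from fairseq import utils

# the new name maps straight to the tanh-approximation implementation
act = utils.get_activation_fn('gelu_accurate')

# the old name keeps working but emits a deprecation warning pointing at gelu_accurate
legacy = utils.get_activation_fn('gelu_fast')

assert act is legacy  # both resolve to the same gelu_accurate function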