Commit 8500bdd0 authored by Liezl Puzon's avatar Liezl Puzon Committed by Facebook Github Bot

Add gelu and gelu_fast as possible activation functions (#653)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/653

After this diff, you can train a transformer model with --activation-fn 'relu', 'gelu', or 'gelu_fast'

gelu_fast is the tanh-based approximation, the default implementation in https://github.com/hendrycks/GELUs/blob/master/mnist_fcn.py#L72-L77.
gelu is the exact erf-based form, the alternate implementation in https://github.com/hendrycks/GELUs/blob/master/mnist_fcn.py#L72-L77 and the default implementation in https://github.com/facebookresearch/XLM. A quick numerical sanity check is sketched below.
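
Illustrative sketch, not part of the commit: a quick check that the exact (erf) GELU and the tanh approximation used by gelu_fast agree closely. The input range is an arbitrary assumption chosen for demonstration.

    import math
    import torch

    x = torch.linspace(-3.0, 3.0, steps=13)
    # exact GELU: x * Phi(x), written via the error function
    exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    # tanh-based approximation (the gelu_fast form)
    approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    print(torch.max(torch.abs(exact - approx)).item())  # on the order of 1e-3 or smaller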

Reviewed By: pipibjc

Differential Revision: D14966006

fbshipit-source-id: 94e95fb99bd548ba47cf23b4999265c7b6833fc1
parent d8d03745
@@ -136,7 +136,7 @@ class MaskedLMEncoder(FairseqEncoder):
         )
         encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
         use_bert_layer_norm = getattr(args, 'bert_layer_norm', False)
-        use_gelu = getattr(args, 'use_gelu', False)
+        use_gelu = getattr(args, 'gelu', False)
         apply_bert_init = getattr(args, 'apply_bert_init', False)
         self.sentence_encoder = TransformerSentenceEncoder(
...
@@ -52,8 +52,8 @@ class TransformerModel(FairseqModel):
                             help='dropout probability')
         parser.add_argument('--attention-dropout', type=float, metavar='D',
                             help='dropout probability for attention weights')
-        parser.add_argument('--relu-dropout', type=float, metavar='D',
-                            help='dropout probability after ReLU in FFN')
+        parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
+                            help='dropout probability after activation in FFN.')
         parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                             help='path to pre-trained encoder embedding')
         parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
@@ -94,6 +94,8 @@ class TransformerModel(FairseqModel):
                                  'Must be used with adaptive_loss criterion'),
         parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                             help='sets adaptive softmax dropout for the tail projections')
+        parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_fast'],
+                            help='Which activation function to use')
         # fmt: on

     @classmethod
@@ -583,7 +585,13 @@ class TransformerEncoderLayer(nn.Module):
         )
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout = args.dropout
-        self.relu_dropout = args.relu_dropout
+        self.activation_fn = utils.get_activation_fn(
+            activation=getattr(args, 'activation_fn', 'relu')
+        )
+        self.activation_dropout = getattr(args, 'activation_dropout', 0)
+        if self.activation_dropout == 0:
+            # for backwards compatibility with models that use args.relu_dropout
+            self.activation_dropout = getattr(args, 'relu_dropout', 0)
         self.normalize_before = args.encoder_normalize_before
         self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
         self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
@@ -627,8 +635,8 @@ class TransformerEncoderLayer(nn.Module):
         residual = x
         x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
-        x = F.relu(self.fc1(x))
-        x = F.dropout(x, p=self.relu_dropout, training=self.training)
+        x = self.activation_fn(self.fc1(x))
+        x = F.dropout(x, p=self.activation_dropout, training=self.training)
         x = self.fc2(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
@@ -668,7 +676,13 @@ class TransformerDecoderLayer(nn.Module):
             dropout=args.attention_dropout,
         )
         self.dropout = args.dropout
-        self.relu_dropout = args.relu_dropout
+        self.activation_fn = utils.get_activation_fn(
+            activation=getattr(args, 'activation_fn', 'relu')
+        )
+        self.activation_dropout = getattr(args, 'activation_dropout', 0)
+        if self.activation_dropout == 0:
+            # for backwards compatibility with models that use args.relu_dropout
+            self.activation_dropout = getattr(args, 'relu_dropout', 0)
         self.normalize_before = args.decoder_normalize_before
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
@@ -752,8 +766,8 @@ class TransformerDecoderLayer(nn.Module):
         residual = x
         x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
-        x = F.relu(self.fc1(x))
-        x = F.dropout(x, p=self.relu_dropout, training=self.training)
+        x = self.activation_fn(self.fc1(x))
+        x = F.dropout(x, p=self.activation_dropout, training=self.training)
         x = self.fc2(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
@@ -814,6 +828,7 @@ def base_lm_architecture(args):
     args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
     args.adaptive_softmax_factor = getattr(args, 'adaptive_softmax_factor', 4)
     args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
+    args.activation_fn = getattr(args, 'activation_fn', 'relu')

     args.character_embeddings = getattr(args, 'character_embeddings', False)
@@ -851,7 +866,7 @@ def transformer_lm_wiki103(args):
     args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '20000,60000')
     args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0.2)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
-    args.relu_dropout = getattr(args, 'relu_dropout', 0.1)
+    args.activation_dropout = getattr(args, 'activation_dropout', 0.1)
     transformer_lm_big(args)
@@ -880,7 +895,8 @@ def base_architecture(args):
     args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False)
     args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.)
-    args.relu_dropout = getattr(args, 'relu_dropout', 0.)
+    args.activation_dropout = getattr(args, 'activation_dropout', 0.)
+    args.activation_fn = getattr(args, 'activation_fn', 'relu')
     args.dropout = getattr(args, 'dropout', 0.1)
     args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
     args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
@@ -943,5 +959,5 @@ def transformer_wmt_en_de_big_t2t(args):
     args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', True)
     args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', True)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
-    args.relu_dropout = getattr(args, 'relu_dropout', 0.1)
+    args.activation_dropout = getattr(args, 'activation_dropout', 0.1)
     transformer_vaswani_wmt_en_de_big(args)
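
Illustrative sketch, not part of the commit: the fallback added to TransformerEncoderLayer and TransformerDecoderLayer above keeps configs that only set relu_dropout working. argparse.Namespace is used here as a stand-in for fairseq's parsed args.

    from argparse import Namespace

    def resolve_activation_dropout(args):
        # mirrors the backwards-compatibility logic in the layer constructors
        value = getattr(args, 'activation_dropout', 0)
        if value == 0:
            value = getattr(args, 'relu_dropout', 0)
        return value

    old_style = Namespace(relu_dropout=0.1)       # pre-diff config/checkpoint
    new_style = Namespace(activation_dropout=0.2)
    print(resolve_activation_dropout(old_style))  # 0.1
    print(resolve_activation_dropout(new_style))  # 0.2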
@@ -13,6 +13,7 @@ from .character_token_embedder import CharacterTokenEmbedder
 from .conv_tbc import ConvTBC
 from .downsampled_multihead_attention import DownsampledMultiHeadAttention
 from .dynamic_convolution import DynamicConv1dTBC
+from .gelu import gelu, gelu_fast
 from .grad_multiply import GradMultiply
 from .highway import Highway
 from .layer_norm import LayerNorm
@@ -37,6 +38,8 @@ __all__ = [
     'ConvTBC',
     'DownsampledMultiHeadAttention',
     'DynamicConv1dTBC',
+    'gelu',
+    'gelu_fast',
     'GradMultiply',
     'Highway',
     'LayerNorm',
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math

import torch

"""
See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with
the corresponding GitHub repo: https://github.com/hendrycks/GELUs
"""


def gelu_fast(x):
    if not hasattr(gelu_fast, "_a"):
        gelu_fast._a = math.sqrt(2 / math.pi)
    return 0.5 * x * (1 + torch.tanh(gelu_fast._a * (x + 0.044715 * torch.pow(x, 3))))


def gelu(x: torch.Tensor) -> torch.Tensor:
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
@@ -10,14 +10,7 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from fairseq.modules import MultiheadAttention, BertLayerNorm
-
-
-def gelu(x: torch.Tensor) -> torch.Tensor:
-    """
-    Implementation of the gelu activation function.
-    """
-    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+from fairseq.modules import gelu, MultiheadAttention, BertLayerNorm


 class TransformerSentenceEncoderLayer(nn.Module):
...
@@ -6,6 +6,7 @@
 # can be found in the PATENTS file in the same directory.

 from collections import defaultdict, OrderedDict
+from typing import Callable
 import copy
 import importlib.util
 import logging
@@ -19,6 +20,8 @@ import torch
 import torch.nn.functional as F
 from torch.serialization import default_restore_location

+from fairseq.modules import gelu, gelu_fast
+

 def torch_persistent_save(*args, **kwargs):
     for i in range(3):
@@ -462,3 +465,15 @@ def log_softmax(x, dim, onnx_trace=False):
 def deprecation_warning(message, stacklevel=3):
     # don't use DeprecationWarning, since it's ignored by default
     warnings.warn(message, stacklevel=stacklevel)
+
+
+def get_activation_fn(activation: str) -> Callable:
+    """ Returns the activation function corresponding to `activation` """
+    if activation == 'relu':
+        return F.relu
+    elif activation == 'gelu':
+        return gelu
+    elif activation == 'gelu_fast':
+        return gelu_fast
+    else:
+        raise RuntimeError(f"--activation-fn {activation} not supported")
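
Illustrative usage sketch, not part of the diff: resolving an activation by name through the new helper; the tensor shape is arbitrary.

    import torch
    from fairseq import utils

    activation_fn = utils.get_activation_fn('gelu_fast')
    y = activation_fn(torch.randn(4, 16))
    # an unrecognized name (e.g. 'swish') raises the RuntimeError shown above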