Commit 8500bdd0 authored by Liezl Puzon, committed by Facebook Github Bot

Add gelu and gelu_fast as possible activation functions (#653)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/653

After this diff, you can train a transformer model with --activation-fn set to 'relu', 'gelu', or 'gelu_fast'.

gelu_fast is the default (tanh-approximation) implementation in https://github.com/hendrycks/GELUs/blob/master/mnist_fcn.py#L72-L77.
gelu is the alternate (erf-based) implementation in https://github.com/hendrycks/GELUs/blob/master/mnist_fcn.py#L72-L77 and the default implementation in https://github.com/facebookresearch/XLM.
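For a concrete sense of how close the two variants are, here is a quick check one could run once this diff lands (editorial sketch, not part of the commit; the tensor shape and tolerance are illustrative):

import torch
from fairseq.modules import gelu, gelu_fast  # both exported by this diff

x = torch.randn(8, 16)
# the tanh approximation (gelu_fast) tracks the erf form (gelu) to within ~1e-3
print(torch.allclose(gelu(x), gelu_fast(x), atol=1e-3))  # True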

Reviewed By: pipibjc

Differential Revision: D14966006

fbshipit-source-id: 94e95fb99bd548ba47cf23b4999265c7b6833fc1
parent d8d03745
@@ -136,7 +136,7 @@ class MaskedLMEncoder(FairseqEncoder):
         )
         encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
         use_bert_layer_norm = getattr(args, 'bert_layer_norm', False)
-        use_gelu = getattr(args, 'use_gelu', False)
+        use_gelu = getattr(args, 'gelu', False)
         apply_bert_init = getattr(args, 'apply_bert_init', False)
         self.sentence_encoder = TransformerSentenceEncoder(
@@ -52,8 +52,8 @@ class TransformerModel(FairseqModel):
                             help='dropout probability')
         parser.add_argument('--attention-dropout', type=float, metavar='D',
                             help='dropout probability for attention weights')
-        parser.add_argument('--relu-dropout', type=float, metavar='D',
-                            help='dropout probability after ReLU in FFN')
+        parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
+                            help='dropout probability after activation in FFN.')
         parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                             help='path to pre-trained encoder embedding')
         parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
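Because '--relu-dropout' is kept as an alias of the new '--activation-dropout' option, existing command lines keep working. A minimal standalone illustration of that argparse behavior (hypothetical parser, not fairseq's own):

import argparse

parser = argparse.ArgumentParser()
# same aliasing pattern as above: the first long option determines the dest
parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D')
print(parser.parse_args(['--relu-dropout', '0.1']).activation_dropout)  # 0.1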
@@ -94,6 +94,8 @@ class TransformerModel(FairseqModel):
                                  'Must be used with adaptive_loss criterion'),
         parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                             help='sets adaptive softmax dropout for the tail projections')
+        parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_fast'],
+                            help='Which activation function to use')
         # fmt: on

     @classmethod
@@ -583,7 +585,13 @@ class TransformerEncoderLayer(nn.Module):
         )
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout = args.dropout
-        self.relu_dropout = args.relu_dropout
+        self.activation_fn = utils.get_activation_fn(
+            activation=getattr(args, 'activation_fn', 'relu')
+        )
+        self.activation_dropout = getattr(args, 'activation_dropout', 0)
+        if self.activation_dropout == 0:
+            # for backwards compatibility with models that use args.relu_dropout
+            self.activation_dropout = getattr(args, 'relu_dropout', 0)
         self.normalize_before = args.encoder_normalize_before
         self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
         self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
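The fallback above means configs that only define relu_dropout still get the intended dropout rate. A small standalone sketch of that logic (hypothetical args namespace, editorial note):

from argparse import Namespace

# old-style config: no activation_dropout, only relu_dropout
old_args = Namespace(relu_dropout=0.1)
activation_dropout = getattr(old_args, 'activation_dropout', 0)
if activation_dropout == 0:
    # same backwards-compatibility branch as in the layer above
    activation_dropout = getattr(old_args, 'relu_dropout', 0)
print(activation_dropout)  # 0.1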
@@ -627,8 +635,8 @@ class TransformerEncoderLayer(nn.Module):
         residual = x
         x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
-        x = F.relu(self.fc1(x))
-        x = F.dropout(x, p=self.relu_dropout, training=self.training)
+        x = self.activation_fn(self.fc1(x))
+        x = F.dropout(x, p=self.activation_dropout, training=self.training)
         x = self.fc2(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
@@ -668,7 +676,13 @@ class TransformerDecoderLayer(nn.Module):
             dropout=args.attention_dropout,
         )
         self.dropout = args.dropout
-        self.relu_dropout = args.relu_dropout
+        self.activation_fn = utils.get_activation_fn(
+            activation=getattr(args, 'activation_fn', 'relu')
+        )
+        self.activation_dropout = getattr(args, 'activation_dropout', 0)
+        if self.activation_dropout == 0:
+            # for backwards compatibility with models that use args.relu_dropout
+            self.activation_dropout = getattr(args, 'relu_dropout', 0)
         self.normalize_before = args.decoder_normalize_before
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
@@ -752,8 +766,8 @@ class TransformerDecoderLayer(nn.Module):
         residual = x
         x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
-        x = F.relu(self.fc1(x))
-        x = F.dropout(x, p=self.relu_dropout, training=self.training)
+        x = self.activation_fn(self.fc1(x))
+        x = F.dropout(x, p=self.activation_dropout, training=self.training)
         x = self.fc2(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
@@ -814,6 +828,7 @@ def base_lm_architecture(args):
     args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
     args.adaptive_softmax_factor = getattr(args, 'adaptive_softmax_factor', 4)
     args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
+    args.activation_fn = getattr(args, 'activation_fn', 'relu')
     args.character_embeddings = getattr(args, 'character_embeddings', False)
@@ -851,7 +866,7 @@ def transformer_lm_wiki103(args):
     args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '20000,60000')
     args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0.2)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
-    args.relu_dropout = getattr(args, 'relu_dropout', 0.1)
+    args.activation_dropout = getattr(args, 'activation_dropout', 0.1)
     transformer_lm_big(args)
@@ -880,7 +895,8 @@ def base_architecture(args):
     args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False)
     args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.)
-    args.relu_dropout = getattr(args, 'relu_dropout', 0.)
+    args.activation_dropout = getattr(args, 'activation_dropout', 0.)
+    args.activation_fn = getattr(args, 'activation_fn', 'relu')
     args.dropout = getattr(args, 'dropout', 0.1)
     args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
     args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
@@ -943,5 +959,5 @@ def transformer_wmt_en_de_big_t2t(args):
     args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', True)
     args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', True)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
-    args.relu_dropout = getattr(args, 'relu_dropout', 0.1)
+    args.activation_dropout = getattr(args, 'activation_dropout', 0.1)
     transformer_vaswani_wmt_en_de_big(args)
@@ -13,6 +13,7 @@ from .character_token_embedder import CharacterTokenEmbedder
 from .conv_tbc import ConvTBC
 from .downsampled_multihead_attention import DownsampledMultiHeadAttention
 from .dynamic_convolution import DynamicConv1dTBC
+from .gelu import gelu, gelu_fast
 from .grad_multiply import GradMultiply
 from .highway import Highway
 from .layer_norm import LayerNorm
@@ -37,6 +38,8 @@ __all__ = [
     'ConvTBC',
     'DownsampledMultiHeadAttention',
     'DynamicConv1dTBC',
+    'gelu',
+    'gelu_fast',
     'GradMultiply',
     'Highway',
     'LayerNorm',
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import math
+
+import torch
+
+"""
+See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with
+the corresponding GitHub repo: https://github.com/hendrycks/GELUs
+"""
+
+
+def gelu_fast(x):
+    if not hasattr(gelu_fast, "_a"):
+        gelu_fast._a = math.sqrt(2 / math.pi)
+    return 0.5 * x * (1 + torch.tanh(gelu_fast._a * (x + 0.044715 * torch.pow(x, 3))))
+
+
+def gelu(x: torch.Tensor) -> torch.Tensor:
+    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
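One detail worth noting in the new module above (editorial observation, not from the commit message): gelu_fast stores sqrt(2 / pi) as an attribute on the function itself after the first call, so the constant is computed only once:

import math
import torch
from fairseq.modules import gelu_fast

gelu_fast(torch.randn(2))                      # first call caches the constant
print(gelu_fast._a == math.sqrt(2 / math.pi))  # True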
@@ -10,14 +10,7 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from fairseq.modules import MultiheadAttention, BertLayerNorm
-
-def gelu(x: torch.Tensor) -> torch.Tensor:
-    """
-    Implementation of the gelu activation function.
-    """
-    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+from fairseq.modules import gelu, MultiheadAttention, BertLayerNorm

 class TransformerSentenceEncoderLayer(nn.Module):
@@ -6,6 +6,7 @@
 # can be found in the PATENTS file in the same directory.
 from collections import defaultdict, OrderedDict
+from typing import Callable
 import copy
 import importlib.util
 import logging
@@ -19,6 +20,8 @@ import torch
 import torch.nn.functional as F
 from torch.serialization import default_restore_location
+from fairseq.modules import gelu, gelu_fast

 def torch_persistent_save(*args, **kwargs):
     for i in range(3):
@@ -462,3 +465,15 @@ def log_softmax(x, dim, onnx_trace=False):
 def deprecation_warning(message, stacklevel=3):
     # don't use DeprecationWarning, since it's ignored by default
     warnings.warn(message, stacklevel=stacklevel)
+
+
+def get_activation_fn(activation: str) -> Callable:
+    """ Returns the activation function corresponding to `activation` """
+    if activation == 'relu':
+        return F.relu
+    elif activation == 'gelu':
+        return gelu
+    elif activation == 'gelu_fast':
+        return gelu_fast
+    else:
+        raise RuntimeError(f"--activation-fn {activation} not supported")
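For reference, a minimal sketch of how the new helper is consumed, mirroring the encoder/decoder layer changes above (editorial sketch, not part of the commit; the Namespace stands in for fairseq's parsed args):

from argparse import Namespace
import torch
from fairseq import utils

args = Namespace(activation_fn='gelu_fast')      # hypothetical parsed args
activation_fn = utils.get_activation_fn(getattr(args, 'activation_fn', 'relu'))
y = activation_fn(torch.randn(4, 8))             # GELU-activated tensor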