Commit ef62ec0a authored by Myle Ott, committed by Facebook Github Bot

Add missing LM options

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/596

Differential Revision: D15432359

Pulled By: myleott

fbshipit-source-id: ebfdf0031864c3c88357543c0202ba0bd65a7b90
parent d10fe896
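In short, the change replaces the hard-coded activation-function choices in each model's add_args with a shared helper, utils.get_available_activation_fns(), and exposes a --decoder-final-norm flag on the Transformer models. A minimal sketch of how the new options behave once parsed (a standalone parser for illustration only; the real definitions live in the diffs below, and the --decoder-final-norm help wording here is my reading of the flag, not the diff's):

import argparse

from fairseq import utils  # assumes a fairseq checkout with this commit applied

parser = argparse.ArgumentParser()
# choices now come from one shared helper instead of per-model literals
parser.add_argument('--activation-fn',
                    choices=utils.get_available_activation_fns(),
                    help='activation function to use')
parser.add_argument('--decoder-final-norm', default=False, action='store_true',
                    help='apply layernorm after the last decoder block')

args = parser.parse_args(['--activation-fn', 'gelu', '--decoder-final-norm'])
activation_fn = utils.get_activation_fn(args.activation_fn)  # name -> callable
print(args.decoder_final_norm)  # True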
@@ -92,10 +92,11 @@ class MaskedLMModel(BaseFairseqModel):
                             help='use custom param initialization for BERT')
         # misc params
-        parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_accurate'],
-                            help='Which activation function to use')
+        parser.add_argument('--activation-fn',
+                            choices=utils.get_available_activation_fns(),
+                            help='activation function to use')
         parser.add_argument('--pooler-activation-fn',
-                            choices=['relu', 'gelu', 'gelu_accurate', 'tanh'],
+                            choices=utils.get_available_activation_fns(),
                             help='Which activation function to use for pooler layer.')
         parser.add_argument('--encoder-normalize-before', action='store_true',
                             help='apply layernorm before each encoder block')
@@ -53,8 +53,9 @@ class TransformerModel(FairseqEncoderDecoderModel):
     def add_args(parser):
         """Add model-specific arguments to the parser."""
         # fmt: off
-        parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_accurate'],
-                            help='Which activation function to use')
+        parser.add_argument('--activation-fn',
+                            choices=utils.get_available_activation_fns(),
+                            help='activation function to use')
         parser.add_argument('--dropout', type=float, metavar='D',
                             help='dropout probability')
         parser.add_argument('--attention-dropout', type=float, metavar='D',
@@ -73,6 +74,8 @@ class TransformerModel(FairseqEncoderDecoderModel):
                             help='num encoder attention heads')
         parser.add_argument('--encoder-normalize-before', action='store_true',
                             help='apply layernorm before each encoder block')
+        parser.add_argument('--decoder-final-norm', default=False, action='store_true',
+                            help='apply layernorm before each decoder block')
         parser.add_argument('--encoder-learned-pos', action='store_true',
                             help='use learned positional embeddings in the encoder')
         parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
@@ -5,7 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
-from fairseq import options
+from fairseq import options, utils
 from fairseq.models import (
     FairseqLanguageModel,
     register_model,
@@ -30,6 +30,9 @@ class TransformerLanguageModel(FairseqLanguageModel):
     def add_args(parser):
         """Add model-specific arguments to the parser."""
         # fmt: off
+        parser.add_argument('--activation-fn',
+                            choices=utils.get_available_activation_fns(),
+                            help='activation function to use')
         parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
                             help='dropout probability')
         parser.add_argument('--attention-dropout', default=0., type=float, metavar='D',
@@ -50,6 +53,8 @@ class TransformerLanguageModel(FairseqLanguageModel):
                             help='num decoder attention heads')
         parser.add_argument('--decoder-normalize-before', default=False, action='store_true',
                             help='apply layernorm before each decoder block')
+        parser.add_argument('--decoder-final-norm', default=False, action='store_true',
+                            help='apply layernorm before each decoder block')
         parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                             help='comma separated list of adaptive softmax cutoff points. '
                                  'Must be used with adaptive_loss criterion')
@@ -123,7 +128,8 @@ class TransformerLanguageModel(FairseqLanguageModel):
         assert args.decoder_input_dim == args.decoder_output_dim
 
         decoder = TransformerDecoder(
-            args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False,
+            args, task.output_dictionary, embed_tokens, no_encoder_attn=True,
+            final_norm=args.decoder_final_norm,
         )
         return TransformerLanguageModel(decoder)
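For reference, the final_norm keyword that --decoder-final-norm now feeds typically gates a single LayerNorm applied after the last decoder layer. A toy sketch of that pattern (not fairseq's actual TransformerDecoder, just the shape of the switch):

import torch.nn as nn

class TinyDecoder(nn.Module):
    """Illustrative decoder: final_norm toggles one LayerNorm after the last layer."""

    def __init__(self, dim, num_layers, final_norm=False):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.TransformerDecoderLayer(d_model=dim, nhead=4) for _ in range(num_layers)]
        )
        # only built (and applied) when requested, mirroring
        # final_norm=args.decoder_final_norm in the hunk above
        self.final_layer_norm = nn.LayerNorm(dim) if final_norm else None

    def forward(self, x, memory):
        for layer in self.layers:
            x = layer(x, memory)
        if self.final_layer_norm is not None:
            x = self.final_layer_norm(x)
        return x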
@@ -6,12 +6,12 @@
 # can be found in the PATENTS file in the same directory.
 
 from collections import defaultdict
-from typing import Callable
 import copy
 import importlib.util
 import math
 import os
 import sys
+from typing import Callable, List
 import warnings
 
 import torch
@@ -314,3 +314,13 @@ def get_activation_fn(activation: str) -> Callable:
         return F.tanh
     else:
         raise RuntimeError(f"--activation-fn {activation} not supported")
+
+
+def get_available_activation_fns() -> List:
+    return [
+        'relu',
+        'gelu',
+        'gelu_fast',  # deprecated
+        'gelu_accurate',
+        'tanh',
+    ]
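The two helpers in the utils module are meant to stay in sync: get_available_activation_fns() feeds every choices= above, while get_activation_fn() (the function this hunk extends) maps the selected name to a callable. A quick sanity check one could run against a fairseq checkout, assuming get_activation_fn() accepts every advertised name, including the deprecated gelu_fast:

from fairseq import utils

# every name offered through --activation-fn should resolve without raising
for name in utils.get_available_activation_fns():
    fn = utils.get_activation_fn(name)
    assert callable(fn), f"{name} did not resolve to a callable"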