"git@developer.sourcefind.cn:OpenDAS/fairseq.git" did not exist on "f5e52c1928ccd9d596ec0fd82219b6a5082f5b35"
Commit ef62ec0a authored by Myle Ott, committed by Facebook GitHub Bot

Add missing LM options

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/596

Differential Revision: D15432359

Pulled By: myleott

fbshipit-source-id: ebfdf0031864c3c88357543c0202ba0bd65a7b90
parent d10fe896
@@ -92,10 +92,11 @@ class MaskedLMModel(BaseFairseqModel):
                             help='use custom param initialization for BERT')
         # misc params
-        parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_accurate'],
-                            help='Which activation function to use')
+        parser.add_argument('--activation-fn',
+                            choices=utils.get_available_activation_fns(),
+                            help='activation function to use')
         parser.add_argument('--pooler-activation-fn',
-                            choices=['relu', 'gelu', 'gelu_accurate', 'tanh'],
+                            choices=utils.get_available_activation_fns(),
                             help='Which activation function to use for pooler layer.')
         parser.add_argument('--encoder-normalize-before', action='store_true',
                             help='apply layernorm before each encoder block')
...
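The masked LM change above replaces per-option hard-coded choices lists with a single shared helper. Below is a minimal standalone sketch of that pattern; the helper body mirrors the list added in this commit, but the toy parser is purely illustrative and not fairseq's actual add_args:

# Sketch only: source argparse choices from one shared helper instead of
# repeating literal lists next to every option.
import argparse

def get_available_activation_fns():
    # mirrors the helper this commit adds to fairseq's utils module
    return ['relu', 'gelu', 'gelu_fast', 'gelu_accurate', 'tanh']

parser = argparse.ArgumentParser()
parser.add_argument('--activation-fn',
                    choices=get_available_activation_fns(),
                    help='activation function to use')
parser.add_argument('--pooler-activation-fn',
                    choices=get_available_activation_fns(),
                    help='activation function to use for pooler layer')

args = parser.parse_args(['--activation-fn', 'gelu'])
assert args.activation_fn == 'gelu'

Keeping the list in one place means a newly supported activation (or a deprecated alias such as 'gelu_fast') only has to be registered once to become selectable from every model's options.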
@@ -53,8 +53,9 @@ class TransformerModel(FairseqEncoderDecoderModel):
     def add_args(parser):
         """Add model-specific arguments to the parser."""
         # fmt: off
-        parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_accurate'],
-                            help='Which activation function to use')
+        parser.add_argument('--activation-fn',
+                            choices=utils.get_available_activation_fns(),
+                            help='activation function to use')
         parser.add_argument('--dropout', type=float, metavar='D',
                             help='dropout probability')
         parser.add_argument('--attention-dropout', type=float, metavar='D',
@@ -73,6 +74,8 @@ class TransformerModel(FairseqEncoderDecoderModel):
                             help='num encoder attention heads')
         parser.add_argument('--encoder-normalize-before', action='store_true',
                             help='apply layernorm before each encoder block')
+        parser.add_argument('--decoder-final-norm', default=False, action='store_true',
+                            help='apply layernorm after the last decoder block')
         parser.add_argument('--encoder-learned-pos', action='store_true',
                             help='use learned positional embeddings in the encoder')
         parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
...
@@ -5,7 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
-from fairseq import options
+from fairseq import options, utils
 from fairseq.models import (
     FairseqLanguageModel,
     register_model,
@@ -30,6 +30,9 @@ class TransformerLanguageModel(FairseqLanguageModel):
     def add_args(parser):
         """Add model-specific arguments to the parser."""
         # fmt: off
+        parser.add_argument('--activation-fn',
+                            choices=utils.get_available_activation_fns(),
+                            help='activation function to use')
         parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
                             help='dropout probability')
         parser.add_argument('--attention-dropout', default=0., type=float, metavar='D',
@@ -50,6 +53,8 @@ class TransformerLanguageModel(FairseqLanguageModel):
                             help='num decoder attention heads')
         parser.add_argument('--decoder-normalize-before', default=False, action='store_true',
                             help='apply layernorm before each decoder block')
+        parser.add_argument('--decoder-final-norm', default=False, action='store_true',
+                            help='apply layernorm after the last decoder block')
         parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                             help='comma separated list of adaptive softmax cutoff points. '
                                  'Must be used with adaptive_loss criterion')
@@ -123,7 +128,8 @@ class TransformerLanguageModel(FairseqLanguageModel):
         assert args.decoder_input_dim == args.decoder_output_dim
 
         decoder = TransformerDecoder(
-            args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False,
+            args, task.output_dictionary, embed_tokens, no_encoder_attn=True,
+            final_norm=args.decoder_final_norm,
         )
         return TransformerLanguageModel(decoder)
...
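In the transformer LM, the previously hard-coded final_norm=False is now driven by the new --decoder-final-norm flag (default False, so existing behaviour is unchanged unless the flag is passed). A minimal sketch of the pattern such a flag typically controls follows; the LayerNorm placement is assumed from the flag's name and the final_norm keyword, and this is not fairseq's real TransformerDecoder:

# Assumed simplification, not fairseq's TransformerDecoder: final_norm toggles
# one extra LayerNorm applied after the last decoder layer.
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    def __init__(self, embed_dim, num_layers, final_norm=False):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Linear(embed_dim, embed_dim) for _ in range(num_layers)
        )
        # the LM used to build its decoder with final_norm hard-coded to False;
        # this commit wires the value through args.decoder_final_norm instead
        self.layer_norm = nn.LayerNorm(embed_dim) if final_norm else None

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        if self.layer_norm is not None:
            x = self.layer_norm(x)
        return x

decoder = TinyDecoder(embed_dim=16, num_layers=2, final_norm=True)
out = decoder(torch.zeros(4, 16))  # (4, 16), passed through the final LayerNorm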
@@ -6,12 +6,12 @@
 # can be found in the PATENTS file in the same directory.
 
 from collections import defaultdict
-from typing import Callable
 import copy
 import importlib.util
 import math
 import os
 import sys
+from typing import Callable, List
 import warnings
 
 import torch
@@ -314,3 +314,13 @@ def get_activation_fn(activation: str) -> Callable:
         return F.tanh
     else:
         raise RuntimeError(f"--activation-fn {activation} not supported")
+
+
+def get_available_activation_fns() -> List:
+    return [
+        'relu',
+        'gelu',
+        'gelu_fast',  # deprecated
+        'gelu_accurate',
+        'tanh',
+    ]
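The new helper pairs with the existing get_activation_fn shown above: argparse restricts --activation-fn to the advertised names, and get_activation_fn resolves the chosen name to a callable when the model is built. A quick consistency check, assuming fairseq is importable and that every listed name, including the deprecated 'gelu_fast', is still accepted by get_activation_fn:

# Sketch: every name advertised to argparse should resolve to a callable
# rather than hitting get_activation_fn's RuntimeError branch.
from fairseq import utils

for name in utils.get_available_activation_fns():
    fn = utils.get_activation_fn(name)
    assert callable(fn), name

Because get_activation_fn raises RuntimeError for unknown names, keeping both functions in the same module makes it easier to keep the advertised list and the dispatch logic in sync.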