Commit 1c565940 authored by Myle Ott, committed by Facebook Github Bot

Fix lightconv_lm and add test (#932)

Summary:
Fixes https://github.com/fairinternal/fairseq-py/issues/536
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/932

Differential Revision: D18783032

Pulled By: myleott

fbshipit-source-id: a520faccc20be78296a228214923ee1495fb536f
parent 5be1cf30
@@ -339,7 +339,7 @@ class LightConvDecoder(FairseqIncrementalDecoder):
         if self.normalize:
             self.layer_norm = LayerNorm(embed_dim)

-    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
         """
         Args:
             prev_output_tokens (LongTensor): previous decoder outputs of shape
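The **kwargs addition is the core of the fix: when the decoder is used as a language model, the calling code forwards every key of the batch's net_input (for example src_lengths) to the decoder, and a signature without **kwargs raises a TypeError on keys it does not use. A minimal, fairseq-independent sketch of the failure mode; the class names and the toy net_input below are illustrative, not fairseq's actual objects:

# Sketch only: shows why accepting **kwargs matters when the caller passes
# extra keyword arguments the decoder does not need.

class StrictDecoder:
    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
        return prev_output_tokens  # placeholder computation

class TolerantDecoder:
    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
        # Unused extras such as src_lengths are accepted and ignored.
        return prev_output_tokens  # placeholder computation

net_input = {"prev_output_tokens": [1, 2, 3], "src_lengths": [3]}

try:
    StrictDecoder().forward(**net_input)
except TypeError as err:
    print("strict decoder fails:", err)

print("tolerant decoder ok:", TolerantDecoder().forward(**net_input))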
@@ -101,9 +101,9 @@ class LightConvLanguageModel(FairseqLanguageModel):
         # make sure all arguments are present in older models
         base_lm_architecture(args)

-        if not hasattr(args, 'max_source_positions'):
+        if getattr(args, 'max_source_positions', None) is None:
             args.max_source_positions = args.tokens_per_sample
-        if not hasattr(args, 'max_target_positions'):
+        if getattr(args, 'max_target_positions', None) is None:
             args.max_target_positions = args.tokens_per_sample

         if args.character_embeddings:
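The hasattr-to-getattr change matters because argparse typically defines every registered option on the namespace, so an unset --max-target-positions shows up as an attribute whose value is None rather than as a missing attribute; hasattr then returns True and the tokens_per_sample fallback is never applied. A small standalone illustration:

from argparse import Namespace

# An args object where the option exists but was left unset (value None),
# which is how argparse-populated namespaces usually look.
args = Namespace(max_target_positions=None, tokens_per_sample=512)

# Old check: the attribute exists, so the fallback is skipped and the value stays None.
if not hasattr(args, 'max_target_positions'):
    args.max_target_positions = args.tokens_per_sample
print(args.max_target_positions)  # None

# New check: treats "missing" and "present but None" the same way.
if getattr(args, 'max_target_positions', None) is None:
    args.max_target_positions = args.tokens_per_sample
print(args.max_target_positions)  # 512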
@@ -145,6 +145,7 @@ def base_lm_architecture(args):
     args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
     args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
     args.decoder_conv_dim = getattr(args, 'decoder_conv_dim', args.decoder_embed_dim)
+    # The model training is not stable without this
     args.decoder_normalize_before = True
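For context, decoder_normalize_before=True selects the pre-norm residual ordering, in which LayerNorm is applied to a sublayer's input rather than to its output; the new comment records that lightconv_lm training was observed to be unstable otherwise. A toy sketch of the two orderings, not fairseq's actual layer code:

import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Toy residual block showing pre-norm vs. post-norm placement of LayerNorm."""

    def __init__(self, dim, normalize_before):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ffn = nn.Linear(dim, dim)  # stand-in for the conv/attention sublayer
        self.normalize_before = normalize_before

    def forward(self, x):
        if self.normalize_before:
            # pre-norm: normalize the input, run the sublayer, add the residual
            return x + self.ffn(self.norm(x))
        # post-norm: run the sublayer, add the residual, then normalize
        return self.norm(x + self.ffn(x))

x = torch.randn(2, 8)
print(ResidualBlock(8, normalize_before=True)(x).shape)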
@@ -159,6 +160,10 @@ def base_lm_architecture(args):
     args.decoder_kernel_size_list = getattr(args, 'decoder_kernel_size_list', [3, 7, 15, 31, 31, 31])
+    if len(args.decoder_kernel_size_list) == 1:
+        args.decoder_kernel_size_list = args.decoder_kernel_size_list * args.decoder_layers
+    assert len(args.decoder_kernel_size_list) == args.decoder_layers, "decoder_kernel_size_list doesn't match decoder_layers"
     args.decoder_glu = getattr(args, 'decoder_glu', True)
     args.input_dropout = getattr(args, 'input_dropout', 0.1)
     args.weight_dropout = getattr(args, 'weight_dropout', args.attention_dropout)

 @register_model_architecture('lightconv_lm', 'lightconv_lm_gbw')
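The new expansion-and-assert logic lets a user pass a single kernel size and have it replicated across every decoder layer, while a mismatched explicit list fails fast instead of erroring deep inside the model. Roughly, the behaviour is:

def expand_kernel_sizes(kernel_size_list, num_layers):
    """Mirrors the architecture-setup logic: replicate a single kernel size
    across all layers, otherwise require exactly one size per layer."""
    if len(kernel_size_list) == 1:
        kernel_size_list = kernel_size_list * num_layers
    assert len(kernel_size_list) == num_layers, \
        "decoder_kernel_size_list doesn't match decoder_layers"
    return kernel_size_list

print(expand_kernel_sizes([15], 6))        # [15, 15, 15, 15, 15, 15]
print(expand_kernel_sizes([3, 7, 15], 3))  # [3, 7, 15]
# expand_kernel_sizes([3, 7], 6) would raise an AssertionError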
@@ -175,9 +175,9 @@ class TransformerModel(FairseqEncoderDecoderModel):
         if args.decoder_layers_to_keep:
             args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

-        if not hasattr(args, 'max_source_positions'):
+        if getattr(args, 'max_source_positions', None) is None:
             args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
-        if not hasattr(args, 'max_target_positions'):
+        if getattr(args, 'max_target_positions', None) is None:
             args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

         src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
@@ -390,6 +390,21 @@ class TestLanguageModeling(unittest.TestCase):
                 '--tokens-per-sample', '500',
             ])

+    def test_lightconv_lm(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory('test_lightconv_lm') as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_lm_data(data_dir)
+                train_language_model(
+                    data_dir, 'lightconv_lm', ['--add-bos-token'], run_validation=True,
+                )
+                eval_lm_main(data_dir)
+                generate_main(data_dir, [
+                    '--task', 'language_modeling',
+                    '--sample-break-mode', 'eos',
+                    '--tokens-per-sample', '500',
+                ])
+

 class TestMaskedLanguageModel(unittest.TestCase):
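For reference, the new test drives preprocessing, training, validation, LM evaluation, and generation on a tiny synthetic dataset through the test-suite helpers. Outside the test suite, roughly the same flow can be reproduced with the fairseq command-line tools; the sketch below shells out to them from Python, and the data path, save directory, and hyperparameter values are placeholders rather than values taken from the test:

import subprocess

data_bin = "data-bin/my_lm_data"       # placeholder: output of fairseq-preprocess --only-source
save_dir = "checkpoints/lightconv_lm"  # placeholder save directory

# Train a small lightconv_lm language model (hyperparameters are illustrative).
subprocess.run([
    "fairseq-train", data_bin,
    "--task", "language_modeling",
    "--arch", "lightconv_lm",
    "--add-bos-token",
    "--optimizer", "adam", "--lr", "0.0005",
    "--tokens-per-sample", "500", "--max-tokens", "500",
    "--max-epoch", "1",
    "--save-dir", save_dir,
], check=True)

# Score held-out data with the trained checkpoint.
subprocess.run([
    "fairseq-eval-lm", data_bin,
    "--path", f"{save_dir}/checkpoint_last.pt",
    "--tokens-per-sample", "500",
], check=True)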