Commit 1c565940 authored by Myle Ott, committed by Facebook Github Bot

Fix lightconv_lm and add test (#932)

Summary:
Fixes https://github.com/fairinternal/fairseq-py/issues/536
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/932

Differential Revision: D18783032

Pulled By: myleott

fbshipit-source-id: a520faccc20be78296a228214923ee1495fb536f
parent 5be1cf30
@@ -339,7 +339,7 @@ class LightConvDecoder(FairseqIncrementalDecoder):
         if self.normalize:
             self.layer_norm = LayerNorm(embed_dim)
 
-    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
         """
         Args:
             prev_output_tokens (LongTensor): previous decoder outputs of shape
...
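Adding **kwargs to LightConvDecoder.forward makes the decoder tolerant of extra keyword arguments that a generic language-model wrapper may forward to it (for example src_lengths from the language-modeling task; the exact caller is an assumption here, not shown in this hunk). A minimal sketch of the mechanism, not the fairseq API:

# Minimal sketch, not the fairseq API: a trailing **kwargs lets forward()
# absorb extra keyword arguments that a generic caller may pass along.

def forward_old(prev_output_tokens, encoder_out=None, incremental_state=None):
    return prev_output_tokens

def forward_new(prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
    # unexpected keys (e.g. a hypothetical 'src_lengths') are simply ignored
    return prev_output_tokens

forward_new([5, 6, 7], src_lengths=[3])    # accepted
# forward_old([5, 6, 7], src_lengths=[3])  # TypeError: unexpected keyword argument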
@@ -101,9 +101,9 @@ class LightConvLanguageModel(FairseqLanguageModel):
         # make sure all arguments are present in older models
         base_lm_architecture(args)
 
-        if not hasattr(args, 'max_source_positions'):
+        if getattr(args, 'max_source_positions', None) is None:
            args.max_source_positions = args.tokens_per_sample
-        if not hasattr(args, 'max_target_positions'):
+        if getattr(args, 'max_target_positions', None) is None:
            args.max_target_positions = args.tokens_per_sample
 
        if args.character_embeddings:
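The switch from hasattr to a getattr-is-None check matters when the attribute is present but unset: argparse-style args commonly define max_source_positions/max_target_positions with a default of None (an assumption about how these args are constructed), so the old hasattr test would leave the None in place instead of falling back to tokens_per_sample. A minimal sketch of the difference, using a plain argparse.Namespace as a stand-in for the parsed args:

import argparse

# Stand-in for parsed args where the option exists but defaulted to None.
args = argparse.Namespace(max_source_positions=None, tokens_per_sample=512)

# Old check: the attribute exists, so the fallback never runs and None leaks through.
if not hasattr(args, 'max_source_positions'):
    args.max_source_positions = args.tokens_per_sample
print(args.max_source_positions)  # None

# New check: an explicit None is treated the same as a missing attribute.
if getattr(args, 'max_source_positions', None) is None:
    args.max_source_positions = args.tokens_per_sample
print(args.max_source_positions)  # 512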
@@ -145,6 +145,7 @@ def base_lm_architecture(args):
     args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
     args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
+    args.decoder_conv_dim = getattr(args, 'decoder_conv_dim', args.decoder_embed_dim)
 
     # The model training is not stable without this
     args.decoder_normalize_before = True
@@ -159,6 +160,10 @@ def base_lm_architecture(args):
     args.decoder_kernel_size_list = getattr(args, 'decoder_kernel_size_list', [3, 7, 15, 31, 31, 31])
     if len(args.decoder_kernel_size_list) == 1:
         args.decoder_kernel_size_list = args.decoder_kernel_size_list * args.decoder_layers
+    assert len(args.decoder_kernel_size_list) == args.decoder_layers, "decoder_kernel_size_list doesn't match decoder_layers"
+    args.decoder_glu = getattr(args, 'decoder_glu', True)
+    args.input_dropout = getattr(args, 'input_dropout', 0.1)
+    args.weight_dropout = getattr(args, 'weight_dropout', args.attention_dropout)
 
 
 @register_model_architecture('lightconv_lm', 'lightconv_lm_gbw')
...
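The added defaults follow the existing getattr(args, name, default) pattern, so older checkpoints and minimal command lines still end up with every field the lightconv decoder expects, and the new assert documents that a one-element decoder_kernel_size_list is broadcast across all decoder layers. A toy illustration of that broadcast, with arbitrary values rather than fairseq defaults:

# Toy illustration of the kernel-size broadcast that the new assert guards
# (the numbers are arbitrary, not fairseq defaults).
decoder_layers = 6
decoder_kernel_size_list = [15]   # a single kernel size supplied by the user

if len(decoder_kernel_size_list) == 1:
    decoder_kernel_size_list = decoder_kernel_size_list * decoder_layers

assert len(decoder_kernel_size_list) == decoder_layers, \
    "decoder_kernel_size_list doesn't match decoder_layers"
print(decoder_kernel_size_list)  # [15, 15, 15, 15, 15, 15]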
@@ -175,9 +175,9 @@ class TransformerModel(FairseqEncoderDecoderModel):
         if args.decoder_layers_to_keep:
             args.decoder_layers = len(args.decoder_layers_to_keep.split(","))
 
-        if not hasattr(args, 'max_source_positions'):
+        if getattr(args, 'max_source_positions', None) is None:
             args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
-        if not hasattr(args, 'max_target_positions'):
+        if getattr(args, 'max_target_positions', None) is None:
             args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
 
         src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
...
@@ -390,6 +390,21 @@ class TestLanguageModeling(unittest.TestCase):
                     '--tokens-per-sample', '500',
                 ])
 
+    def test_lightconv_lm(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory('test_lightconv_lm') as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_lm_data(data_dir)
+                train_language_model(
+                    data_dir, 'lightconv_lm', ['--add-bos-token'], run_validation=True,
+                )
+                eval_lm_main(data_dir)
+                generate_main(data_dir, [
+                    '--task', 'language_modeling',
+                    '--sample-break-mode', 'eos',
+                    '--tokens-per-sample', '500',
+                ])
+
 
 class TestMaskedLanguageModel(unittest.TestCase):
...