Commit 5408bc08 authored by Matt Le, committed by Facebook Github Bot

Fix loading XLM pretraining

Summary: We never actually load the model parameters from an XLM model when using transformer_from_pretrained_xlm. Also, change the encoder_learned_pos default from True to False.

Reviewed By: liezl200

Differential Revision: D15629061

fbshipit-source-id: 759eadc88041eae94505477960de57dd78a99dcb
parent 0d636744
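The core of the fix is that the encoder built for transformer_from_pretrained_xlm must actually copy parameters out of the XLM checkpoint. Below is a minimal sketch of what that copy can look like, assuming a fairseq-style checkpoint whose parameters sit under a "model" key; the helper name load_xlm_weights_into_encoder is hypothetical, not the function fairseq uses.

# Hedged sketch, not the exact fairseq implementation: copy matching encoder
# parameters from a pretrained XLM checkpoint into a freshly built encoder.
# The "model" key and the "encoder." prefix are assumptions about the layout.
import torch


def load_xlm_weights_into_encoder(encoder, xlm_checkpoint_path):
    state = torch.load(xlm_checkpoint_path, map_location="cpu")
    xlm_params = state.get("model", state)
    encoder_state = encoder.state_dict()
    for name, tensor in xlm_params.items():
        # Strip a possible "encoder." prefix so names line up with the new module.
        key = name[len("encoder."):] if name.startswith("encoder.") else name
        if key in encoder_state and encoder_state[key].shape == tensor.shape:
            encoder_state[key] = tensor
    encoder.load_state_dict(encoder_state)
    return encoder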
@@ -337,7 +337,7 @@ def xlm_architecture(args):
         args, 'share_encoder_input_output_embed', True)
     args.no_token_positional_embeddings = getattr(
         args, 'no_token_positional_embeddings', False)
-    args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', True)
+    args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False)
     args.num_segment = getattr(args, 'num_segment', 1)
     args.encoder_layers = getattr(args, 'encoder_layers', 6)
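For context on why the flipped default matters: the encoder chooses between learned and sinusoidal positional embeddings based on this flag, so XLM-style models now fall back to sinusoidal positions unless --encoder-learned-pos is passed explicitly. A simplified, illustrative sketch of that selection follows; the factory function here is not fairseq's actual helper and its signatures are simplified.

# Simplified, illustrative selection logic, not fairseq's actual factory.
import torch.nn as nn
from fairseq.modules import SinusoidalPositionalEmbedding


def build_encoder_positional_embedding(args, embedding_dim, padding_idx, max_positions):
    if getattr(args, 'encoder_learned_pos', False):
        # Learned positions: a trainable embedding over position indices.
        return nn.Embedding(max_positions, embedding_dim, padding_idx=padding_idx)
    # Default (now False): fixed sinusoidal positions, see the module patched below.
    return SinusoidalPositionalEmbedding(embedding_dim, padding_idx, init_size=max_positions)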
@@ -153,10 +153,18 @@ class TransformerModel(FairseqEncoderDecoderModel):
             tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
         )
-        encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
-        decoder = TransformerDecoder(args, tgt_dict, decoder_embed_tokens)
+        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
+        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
         return TransformerModel(encoder, decoder)

+    @classmethod
+    def build_encoder(cls, args, src_dict, embed_tokens):
+        return TransformerEncoder(args, src_dict, embed_tokens)
+
+    @classmethod
+    def build_decoder(cls, args, tgt_dict, embed_tokens):
+        return TransformerDecoder(args, tgt_dict, embed_tokens)
+

 class TransformerEncoder(FairseqEncoder):
     """
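These build_encoder/build_decoder hooks are what let the pretrained-XLM variant substitute its own encoder construction. A hedged sketch of how a subclass could use the hook to load checkpoint weights is shown below; the class name is illustrative, load_xlm_weights_into_encoder is the hypothetical helper sketched near the top of this page, and fairseq's real model class behind transformer_from_pretrained_xlm differs in detail.

# Hedged sketch: a subclass overrides the new hook so its encoder starts from
# pretrained XLM weights. Class name and helper are illustrative, not fairseq's.
from fairseq.models.transformer import TransformerModel, TransformerEncoder


class TransformerFromPretrainedXLMSketch(TransformerModel):

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        encoder = TransformerEncoder(args, src_dict, embed_tokens)
        if getattr(args, 'pretrained_xlm_checkpoint', None):
            load_xlm_weights_into_encoder(encoder, args.pretrained_xlm_checkpoint)
        return encoder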
@@ -54,7 +54,7 @@ class SinusoidalPositionalEmbedding(nn.Module):
         emb[padding_idx, :] = 0
         return emb

-    def forward(self, input, incremental_state=None, timestep=None):
+    def forward(self, input, incremental_state=None, timestep=None, **kwargs):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = torch.onnx.operators.shape_as_tensor(input)
        max_pos = self.padding_idx + 1 + seq_len
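The added **kwargs makes the sinusoidal module tolerant of keyword arguments that only other positional-embedding variants consume, so call sites shared by both embedding types no longer raise TypeError. A minimal illustration under an assumed setup follows; the module import path and constructor arguments are taken from fairseq as I understand them, and the extra positions keyword is just an example of an argument that gets ignored.

# Minimal illustration (assumed setup): the extra keyword argument is simply
# swallowed by **kwargs instead of raising a TypeError.
import torch
from fairseq.modules import SinusoidalPositionalEmbedding

embed_positions = SinusoidalPositionalEmbedding(32, padding_idx=1, init_size=128)
tokens = torch.randint(2, 100, (2, 7))         # dummy [bsz x seqlen] token ids
pos = embed_positions(tokens, positions=None)  # accepted now that forward takes **kwargs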
@@ -264,12 +264,16 @@ class TestMaskedLanguageModel(unittest.TestCase):
                 preprocess_lm_data(data_dir)
                 train_masked_language_model(data_dir, "xlm_base")

-    def test_pretrained_masked_lm_for_translation(self):
+    def _test_pretrained_masked_lm_for_translation(self, learned_pos_emb, encoder_only):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory("test_mlm") as data_dir:
                 create_dummy_data(data_dir)
                 preprocess_lm_data(data_dir)
-                train_masked_language_model(data_dir, arch="xlm_base")
+                train_masked_language_model(
+                    data_dir,
+                    arch="xlm_base",
+                    extra_args=('--encoder-learned-pos',) if learned_pos_emb else ()
+                )
                 with tempfile.TemporaryDirectory(
                     "test_mlm_translation"
                 ) as translation_dir:
@@ -300,68 +304,29 @@ class TestMaskedLanguageModel(unittest.TestCase):
                             "32",
                             "--pretrained-xlm-checkpoint",
                             f"{data_dir}/checkpoint_last.pt",
-                            "--encoder-learned-pos",
-                            "--decoder-learned-pos",
                             "--activation-fn",
                             "gelu",
                             "--max-source-positions",
                             "500",
                             "--max-target-positions",
                             "500",
-                        ],
+                        ] + (
+                            ["--encoder-learned-pos", "--decoder-learned-pos"]
+                            if learned_pos_emb else []
+                        ) + (['--init-encoder-only'] if encoder_only else []),
                         task="translation_from_pretrained_xlm",
                     )

+    def test_pretrained_masked_lm_for_translation_learned_pos_emb(self):
+        self._test_pretrained_masked_lm_for_translation(True, False)
+
+    def test_pretrained_masked_lm_for_translation_sinusoidal_pos_emb(self):
+        self._test_pretrained_masked_lm_for_translation(False, False)
+
     def test_pretrained_masked_lm_for_translation_encoder_only(self):
-        with contextlib.redirect_stdout(StringIO()):
-            with tempfile.TemporaryDirectory("test_mlm") as data_dir:
-                create_dummy_data(data_dir)
-                preprocess_lm_data(data_dir)
-                train_masked_language_model(data_dir, arch="xlm_base")
-                with tempfile.TemporaryDirectory(
-                    "test_mlm_translation"
-                ) as translation_dir:
-                    create_dummy_data(translation_dir)
-                    preprocess_translation_data(
-                        translation_dir, extra_flags=["--joined-dictionary"]
-                    )
-                    # Train transformer with data_dir/checkpoint_last.pt
-                    train_translation_model(
-                        translation_dir,
-                        arch="transformer_from_pretrained_xlm",
-                        extra_flags=[
-                            "--decoder-layers",
-                            "1",
-                            "--decoder-embed-dim",
-                            "32",
-                            "--decoder-attention-heads",
-                            "1",
-                            "--decoder-ffn-embed-dim",
-                            "32",
-                            "--encoder-layers",
-                            "1",
-                            "--encoder-embed-dim",
-                            "32",
-                            "--encoder-attention-heads",
-                            "1",
-                            "--encoder-ffn-embed-dim",
-                            "32",
-                            "--pretrained-xlm-checkpoint",
-                            f"{data_dir}/checkpoint_last.pt",
-                            "--encoder-learned-pos",
-                            "--decoder-learned-pos",
-                            "--activation-fn",
-                            "gelu",
-                            "--max-source-positions",
-                            "500",
-                            "--max-target-positions",
-                            "500",
-                            "--init-encoder-only",
-                        ],
-                        task="translation_from_pretrained_xlm",
-                    )
+        self._test_pretrained_masked_lm_for_translation(True, True)


-def train_masked_language_model(data_dir, arch):
+def train_masked_language_model(data_dir, arch, extra_args=()):
     train_parser = options.get_training_parser()
     # TODO: langs should be in and out right?
     train_args = options.parse_args_and_arch(
@@ -419,7 +384,7 @@ def train_masked_language_model(data_dir, arch):
             "1",
             "--dataset-impl",
             "raw",
-        ],
+        ] + list(extra_args),
     )
     train.main(train_args)
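Usage-wise, the extra_args plumbing means the tests only opt into learned positional embeddings when a scenario needs them, matching the new default. For example, taken directly from the refactored tests:

# Sinusoidal positions by default, learned positions only when the flag is
# passed through extra_args.
train_masked_language_model(data_dir, arch="xlm_base")
train_masked_language_model(data_dir, arch="xlm_base", extra_args=("--encoder-learned-pos",))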