Commit 5408bc08 authored by Matt Le, committed by Facebook GitHub Bot

Fix loading XLM pretraining

Summary: We never actually load the model parameters from an XLM checkpoint when using transformer_from_pretrained_xlm. Also, change the encoder_learned_pos default from True to False.

Reviewed By: liezl200

Differential Revision: D15629061

fbshipit-source-id: 759eadc88041eae94505477960de57dd78a99dcb
parent 0d636744
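For context on the fix below: build_model used to instantiate TransformerEncoder/TransformerDecoder directly, so a model registered for the transformer_from_pretrained_xlm architecture had no hook through which to substitute components that read --pretrained-xlm-checkpoint. The new build_encoder/build_decoder classmethods provide that hook. A minimal sketch of how such a subclass could use it — the class name, key-matching logic, and checkpoint layout here are illustrative assumptions, not the actual fairseq implementation:

```python
import torch

from fairseq.models.transformer import TransformerEncoder, TransformerModel


class TransformerFromPretrainedXLMSketch(TransformerModel):
    """Illustrative subclass relying on the build_encoder hook added in this commit."""

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        encoder = TransformerEncoder(args, src_dict, embed_tokens)
        ckpt_path = getattr(args, "pretrained_xlm_checkpoint", None)
        if ckpt_path:
            # Assumes the usual fairseq checkpoint layout: {"model": state_dict, ...}.
            xlm_state = torch.load(ckpt_path, map_location="cpu")["model"]
            own_state = encoder.state_dict()
            # Copy only parameters whose names and shapes line up with this encoder;
            # the real loader may instead remap key prefixes from the XLM model.
            compatible = {
                k: v for k, v in xlm_state.items()
                if k in own_state and own_state[k].shape == v.shape
            }
            encoder.load_state_dict(compatible, strict=False)
        return encoder
```

With a hook like this in place, only the subclass needs to change when the loading logic does; TransformerModel.build_model itself stays untouched.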
@@ -337,7 +337,7 @@ def xlm_architecture(args):
         args, 'share_encoder_input_output_embed', True)
     args.no_token_positional_embeddings = getattr(
         args, 'no_token_positional_embeddings', False)
-    args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', True)
+    args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False)
     args.num_segment = getattr(args, 'num_segment', 1)
 
     args.encoder_layers = getattr(args, 'encoder_layers', 6)
@@ -153,10 +153,18 @@ class TransformerModel(FairseqEncoderDecoderModel):
                 tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
             )
 
-        encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
-        decoder = TransformerDecoder(args, tgt_dict, decoder_embed_tokens)
+        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
+        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
         return TransformerModel(encoder, decoder)
 
+    @classmethod
+    def build_encoder(cls, args, src_dict, embed_tokens):
+        return TransformerEncoder(args, src_dict, embed_tokens)
+
+    @classmethod
+    def build_decoder(cls, args, tgt_dict, embed_tokens):
+        return TransformerDecoder(args, tgt_dict, embed_tokens)
+
 
 class TransformerEncoder(FairseqEncoder):
     """
@@ -54,7 +54,7 @@ class SinusoidalPositionalEmbedding(nn.Module):
             emb[padding_idx, :] = 0
         return emb
 
-    def forward(self, input, incremental_state=None, timestep=None):
+    def forward(self, input, incremental_state=None, timestep=None, **kwargs):
         """Input is expected to be of size [bsz x seqlen]."""
         bsz, seq_len = torch.onnx.operators.shape_as_tensor(input)
         max_pos = self.padding_idx + 1 + seq_len
@@ -264,12 +264,16 @@ class TestMaskedLanguageModel(unittest.TestCase):
                 preprocess_lm_data(data_dir)
                 train_masked_language_model(data_dir, "xlm_base")
 
-    def test_pretrained_masked_lm_for_translation(self):
+    def _test_pretrained_masked_lm_for_translation(self, learned_pos_emb, encoder_only):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory("test_mlm") as data_dir:
                 create_dummy_data(data_dir)
                 preprocess_lm_data(data_dir)
-                train_masked_language_model(data_dir, arch="xlm_base")
+                train_masked_language_model(
+                    data_dir,
+                    arch="xlm_base",
+                    extra_args=('--encoder-learned-pos',) if learned_pos_emb else ()
+                )
                 with tempfile.TemporaryDirectory(
                     "test_mlm_translation"
                 ) as translation_dir:
@@ -300,68 +304,29 @@ class TestMaskedLanguageModel(unittest.TestCase):
                             "32",
                             "--pretrained-xlm-checkpoint",
                             f"{data_dir}/checkpoint_last.pt",
-                            "--encoder-learned-pos",
-                            "--decoder-learned-pos",
                             "--activation-fn",
                             "gelu",
                             "--max-source-positions",
                             "500",
                             "--max-target-positions",
                             "500",
-                        ],
+                        ] + (
+                            ["--encoder-learned-pos", "--decoder-learned-pos"]
+                            if learned_pos_emb else []
+                        ) + (['--init-encoder-only'] if encoder_only else []),
                         task="translation_from_pretrained_xlm",
                     )
 
+    def test_pretrained_masked_lm_for_translation_learned_pos_emb(self):
+        self._test_pretrained_masked_lm_for_translation(True, False)
+
+    def test_pretrained_masked_lm_for_translation_sinusoidal_pos_emb(self):
+        self._test_pretrained_masked_lm_for_translation(False, False)
+
     def test_pretrained_masked_lm_for_translation_encoder_only(self):
-        with contextlib.redirect_stdout(StringIO()):
-            with tempfile.TemporaryDirectory("test_mlm") as data_dir:
-                create_dummy_data(data_dir)
-                preprocess_lm_data(data_dir)
-                train_masked_language_model(data_dir, arch="xlm_base")
-                with tempfile.TemporaryDirectory(
-                    "test_mlm_translation"
-                ) as translation_dir:
-                    create_dummy_data(translation_dir)
-                    preprocess_translation_data(
-                        translation_dir, extra_flags=["--joined-dictionary"]
-                    )
-                    # Train transformer with data_dir/checkpoint_last.pt
-                    train_translation_model(
-                        translation_dir,
-                        arch="transformer_from_pretrained_xlm",
-                        extra_flags=[
-                            "--decoder-layers",
-                            "1",
-                            "--decoder-embed-dim",
-                            "32",
-                            "--decoder-attention-heads",
-                            "1",
-                            "--decoder-ffn-embed-dim",
-                            "32",
-                            "--encoder-layers",
-                            "1",
-                            "--encoder-embed-dim",
-                            "32",
-                            "--encoder-attention-heads",
-                            "1",
-                            "--encoder-ffn-embed-dim",
-                            "32",
-                            "--pretrained-xlm-checkpoint",
-                            f"{data_dir}/checkpoint_last.pt",
-                            "--encoder-learned-pos",
-                            "--decoder-learned-pos",
-                            "--activation-fn",
-                            "gelu",
-                            "--max-source-positions",
-                            "500",
-                            "--max-target-positions",
-                            "500",
-                            "--init-encoder-only",
-                        ],
-                        task="translation_from_pretrained_xlm",
-                    )
+        self._test_pretrained_masked_lm_for_translation(True, True)
 
 
-def train_masked_language_model(data_dir, arch):
+def train_masked_language_model(data_dir, arch, extra_args=()):
     train_parser = options.get_training_parser()
     # TODO: langs should be in and out right?
     train_args = options.parse_args_and_arch(
@@ -419,7 +384,7 @@ def train_masked_language_model(data_dir, arch):
             "1",
             "--dataset-impl",
             "raw",
-        ],
+        ] + list(extra_args),
    )
    train.main(train_args)