Commit 5008fd4e authored by Liezl Puzon, committed by Facebook Github Bot

XLM for NMT: option to only load encoder or decoder (#666)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/666

Adds an option to load the pretrained XLM weights into only the encoder or only the decoder.
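The two new flags are mutually exclusive and simply skip the XLM weight transfer for the half of the model that should stay randomly initialized. A minimal standalone sketch of the intended gating, in plain Python outside of fairseq (`xlm_weight_targets` is a hypothetical helper used only for illustration):

```python
from argparse import Namespace


def xlm_weight_targets(args):
    """Return which components would receive the pretrained XLM weights."""
    # Mirrors the new assertion in build_model(): both flags may not be set.
    assert not (
        getattr(args, "init_encoder_only", False)
        and getattr(args, "init_decoder_only", False)
    ), "Only one of --init-encoder-only and --init-decoder-only can be set."

    targets = []
    if not getattr(args, "init_decoder_only", False):
        targets.append("encoder")  # encoder skips loading under --init-decoder-only
    if not getattr(args, "init_encoder_only", False):
        targets.append("decoder")  # decoder skips loading under --init-encoder-only
    return targets


print(xlm_weight_targets(Namespace(init_encoder_only=True)))  # ['encoder']
print(xlm_weight_targets(Namespace(init_decoder_only=True)))  # ['decoder']
print(xlm_weight_targets(Namespace()))                        # ['encoder', 'decoder']
```

With `--init-encoder-only`, only the encoder receives the pretrained weights; with `--init-decoder-only`, only the decoder does; with neither flag, both are initialized from the XLM checkpoint as before.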

Reviewed By: pipibjc

Differential Revision: D14881004

fbshipit-source-id: 6d0d598ea9c445ec468f71b8e855712de89a5dac
parent 8da9b1c5
@@ -32,6 +32,16 @@ class TransformerFromPretrainedXLMModel(TransformerModel):
            metavar="STR",
            help="XLM model to use for initializing transformer encoder and/or decoder",
        )
        parser.add_argument(
            "--init-encoder-only",
            action="store_true",
            help="if set, don't load the XLM weights and embeddings into decoder",
        )
        parser.add_argument(
            "--init-decoder-only",
            action="store_true",
            help="if set, don't load the XLM weights and embeddings into encoder",
        )

    @classmethod
    def build_model(cls, args, task):
@@ -48,7 +58,10 @@ class TransformerFromPretrainedXLMModel(TransformerModel):
            "For translation, you may want to use --task "
            "translation_from_pretrained_xlm"
        )
        assert not (
            getattr(args, "init_encoder_only", False)
            and getattr(args, "init_decoder_only", False)
        ), "Only one of --init-encoder-only and --init-decoder-only can be set."
        return super().build_model(args, task)

    @classmethod
@@ -100,6 +113,10 @@ def upgrade_state_dict_with_xlm_weights(
class TransformerEncoderFromPretrainedXLM(TransformerEncoder):
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)
        if getattr(args, 'init_decoder_only', False):
            # Don't load XLM weights for encoder if --init-decoder-only
            return

        assert hasattr(args, "pretrained_xlm_checkpoint"), (
            "--pretrained-xlm-checkpoint must be specified to load Transformer "
            "encoder from pretrained XLM"
@@ -118,6 +135,9 @@ class TransformerDecoderFromPretrainedXLM(TransformerDecoder):
        super().__init__(
            args, dictionary, embed_tokens, no_encoder_attn, final_norm
        )
        if getattr(args, 'init_encoder_only', False):
            # Don't load XLM weights for decoder if --init-encoder-only
            return
        assert hasattr(args, "pretrained_xlm_checkpoint"), (
            "--pretrained-xlm-checkpoint must be specified to load Transformer "
            "decoder from pretrained XLM"
...
@@ -276,6 +276,54 @@ class TestMaskedLanguageModel(unittest.TestCase):
                        task="translation_from_pretrained_xlm",
                    )

    def test_pretrained_masked_lm_for_translation_encoder_only(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory("test_mlm") as data_dir:
                create_dummy_data(data_dir)
                preprocess_lm_data(data_dir)
                train_masked_language_model(data_dir, arch="xlm_base")
                with tempfile.TemporaryDirectory(
                    "test_mlm_translation"
                ) as translation_dir:
                    create_dummy_data(translation_dir)
                    preprocess_translation_data(
                        translation_dir, extra_flags=["--joined-dictionary"]
                    )
                    # Train transformer with data_dir/checkpoint_last.pt
                    train_translation_model(
                        translation_dir,
                        arch="transformer_from_pretrained_xlm",
                        extra_flags=[
                            "--decoder-layers",
                            "1",
                            "--decoder-embed-dim",
                            "32",
                            "--decoder-attention-heads",
                            "1",
                            "--decoder-ffn-embed-dim",
                            "32",
                            "--encoder-layers",
                            "1",
                            "--encoder-embed-dim",
                            "32",
                            "--encoder-attention-heads",
                            "1",
                            "--encoder-ffn-embed-dim",
                            "32",
                            "--pretrained-xlm-checkpoint",
                            f"{data_dir}/checkpoint_last.pt",
                            "--encoder-learned-pos",
                            "--decoder-learned-pos",
                            "--activation-fn",
                            "gelu",
                            "--max-source-positions",
                            "500",
                            "--max-target-positions",
                            "500",
                            "--init-encoder-only",
                        ],
                        task="translation_from_pretrained_xlm",
                    )

def train_masked_language_model(data_dir, arch):
    train_parser = options.get_training_parser()
...