"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "eac3dc7bab436725b0ba65e556d3a6ffd43c24e1"
Commit fa7791df authored by Matt Le, committed by Facebook Github Bot
Browse files

Change encoder_learned_pos default back to True for xlm_base

Reviewed By: pipibjc

Differential Revision: D15635402

fbshipit-source-id: e92fab914de40775d7bad851420355240d822bde
parent 5408bc08
...@@ -337,7 +337,7 @@ def xlm_architecture(args): ...@@ -337,7 +337,7 @@ def xlm_architecture(args):
args, 'share_encoder_input_output_embed', True) args, 'share_encoder_input_output_embed', True)
args.no_token_positional_embeddings = getattr( args.no_token_positional_embeddings = getattr(
args, 'no_token_positional_embeddings', False) args, 'no_token_positional_embeddings', False)
args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False) args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', True)
args.num_segment = getattr(args, 'num_segment', 1) args.num_segment = getattr(args, 'num_segment', 1)
args.encoder_layers = getattr(args, 'encoder_layers', 6) args.encoder_layers = getattr(args, 'encoder_layers', 6)
......
...@@ -262,7 +262,7 @@ class TestMaskedLanguageModel(unittest.TestCase): ...@@ -262,7 +262,7 @@ class TestMaskedLanguageModel(unittest.TestCase):
with tempfile.TemporaryDirectory("test_mlm") as data_dir: with tempfile.TemporaryDirectory("test_mlm") as data_dir:
create_dummy_data(data_dir) create_dummy_data(data_dir)
preprocess_lm_data(data_dir) preprocess_lm_data(data_dir)
train_masked_language_model(data_dir, "xlm_base") train_masked_language_model(data_dir, "masked_lm")
def _test_pretrained_masked_lm_for_translation(self, learned_pos_emb, encoder_only): def _test_pretrained_masked_lm_for_translation(self, learned_pos_emb, encoder_only):
with contextlib.redirect_stdout(StringIO()): with contextlib.redirect_stdout(StringIO()):
...@@ -271,7 +271,7 @@ class TestMaskedLanguageModel(unittest.TestCase): ...@@ -271,7 +271,7 @@ class TestMaskedLanguageModel(unittest.TestCase):
preprocess_lm_data(data_dir) preprocess_lm_data(data_dir)
train_masked_language_model( train_masked_language_model(
data_dir, data_dir,
arch="xlm_base", arch="masked_lm",
extra_args=('--encoder-learned-pos',) if learned_pos_emb else () extra_args=('--encoder-learned-pos',) if learned_pos_emb else ()
) )
with tempfile.TemporaryDirectory( with tempfile.TemporaryDirectory(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment