Commit 7c89e13f authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Fix tests

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/822

Differential Revision: D16800078

Pulled By: myleott

fbshipit-source-id: b86e08e01f2fe13c64b77f1d23a5f6800f252bf7
parent baa8ce11
...@@ -172,7 +172,10 @@ def check_encoder_output(encoder_output, batch_size=None): ...@@ -172,7 +172,10 @@ def check_encoder_output(encoder_output, batch_size=None):
"encoder_padding_mask must be a torch.Tensor" + _current_postion_info() "encoder_padding_mask must be a torch.Tensor" + _current_postion_info()
) )
return False, msg return False, msg
if mask.dtype != torch.uint8: if (
mask.dtype != torch.uint8
and (not hasattr(torch, 'bool') or mask.dtype != torch.bool)
):
msg = ( msg = (
"encoder_padding_mask must have dtype of uint8" "encoder_padding_mask must have dtype of uint8"
+ _current_postion_info() + _current_postion_info()
......
...@@ -151,7 +151,7 @@ class TestTranslation(unittest.TestCase): ...@@ -151,7 +151,7 @@ class TestTranslation(unittest.TestCase):
'--decoder-layers', '2', '--decoder-layers', '2',
'--encoder-embed-dim', '8', '--encoder-embed-dim', '8',
'--decoder-embed-dim', '8', '--decoder-embed-dim', '8',
]) ], run_validation=True)
generate_main(data_dir) generate_main(data_dir)
def test_lightconv(self): def test_lightconv(self):
...@@ -257,7 +257,9 @@ class TestLanguageModeling(unittest.TestCase): ...@@ -257,7 +257,9 @@ class TestLanguageModeling(unittest.TestCase):
with tempfile.TemporaryDirectory('test_transformer_lm') as data_dir: with tempfile.TemporaryDirectory('test_transformer_lm') as data_dir:
create_dummy_data(data_dir) create_dummy_data(data_dir)
preprocess_lm_data(data_dir) preprocess_lm_data(data_dir)
train_language_model(data_dir, 'transformer_lm', ['--add-bos-token']) train_language_model(
data_dir, 'transformer_lm', ['--add-bos-token'], run_validation=True,
)
eval_lm_main(data_dir) eval_lm_main(data_dir)
...@@ -457,7 +459,7 @@ def preprocess_translation_data(data_dir, extra_flags=None): ...@@ -457,7 +459,7 @@ def preprocess_translation_data(data_dir, extra_flags=None):
preprocess.main(preprocess_args) preprocess.main(preprocess_args)
def train_translation_model(data_dir, arch, extra_flags=None, task='translation'): def train_translation_model(data_dir, arch, extra_flags=None, task='translation', run_validation=False):
train_parser = options.get_training_parser() train_parser = options.get_training_parser()
train_args = options.parse_args_and_arch( train_args = options.parse_args_and_arch(
train_parser, train_parser,
...@@ -477,20 +479,21 @@ def train_translation_model(data_dir, arch, extra_flags=None, task='translation' ...@@ -477,20 +479,21 @@ def train_translation_model(data_dir, arch, extra_flags=None, task='translation'
) )
train.main(train_args) train.main(train_args)
# test validation if run_validation:
validate_parser = options.get_validation_parser() # test validation
validate_args = options.parse_args_and_arch( validate_parser = options.get_validation_parser()
validate_parser, validate_args = options.parse_args_and_arch(
[ validate_parser,
'--task', task, [
data_dir, '--task', task,
'--path', os.path.join(data_dir, 'checkpoint_last.pt'), data_dir,
'--valid-subset', 'valid', '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
'--max-tokens', '500', '--valid-subset', 'valid',
'--no-progress-bar', '--max-tokens', '500',
] '--no-progress-bar',
) ]
validate.main(validate_args) )
validate.main(validate_args)
def generate_main(data_dir, extra_flags=None): def generate_main(data_dir, extra_flags=None):
...@@ -534,7 +537,7 @@ def preprocess_lm_data(data_dir): ...@@ -534,7 +537,7 @@ def preprocess_lm_data(data_dir):
preprocess.main(preprocess_args) preprocess.main(preprocess_args)
def train_language_model(data_dir, arch, extra_flags=None): def train_language_model(data_dir, arch, extra_flags=None, run_validation=False):
train_parser = options.get_training_parser() train_parser = options.get_training_parser()
train_args = options.parse_args_and_arch( train_args = options.parse_args_and_arch(
train_parser, train_parser,
...@@ -557,20 +560,21 @@ def train_language_model(data_dir, arch, extra_flags=None): ...@@ -557,20 +560,21 @@ def train_language_model(data_dir, arch, extra_flags=None):
) )
train.main(train_args) train.main(train_args)
# test validation if run_validation:
validate_parser = options.get_validation_parser() # test validation
validate_args = options.parse_args_and_arch( validate_parser = options.get_validation_parser()
validate_parser, validate_args = options.parse_args_and_arch(
[ validate_parser,
'--task', 'language_modeling', [
data_dir, '--task', 'language_modeling',
'--path', os.path.join(data_dir, 'checkpoint_last.pt'), data_dir,
'--valid-subset', 'valid', '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
'--max-tokens', '500', '--valid-subset', 'valid',
'--no-progress-bar', '--max-tokens', '500',
] '--no-progress-bar',
) ]
validate.main(validate_args) )
validate.main(validate_args)
def eval_lm_main(data_dir): def eval_lm_main(data_dir):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment