Commit 7c89e13f authored by Myle Ott, committed by Facebook Github Bot

Fix tests

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/822

Differential Revision: D16800078

Pulled By: myleott

fbshipit-source-id: b86e08e01f2fe13c64b77f1d23a5f6800f252bf7
parent baa8ce11
@@ -172,7 +172,10 @@ def check_encoder_output(encoder_output, batch_size=None):
                 "encoder_padding_mask must be a torch.Tensor" + _current_postion_info()
             )
             return False, msg
-        if mask.dtype != torch.uint8:
+        if (
+            mask.dtype != torch.uint8
+            and (not hasattr(torch, 'bool') or mask.dtype != torch.bool)
+        ):
             msg = (
                 "encoder_padding_mask must have dtype of uint8"
                 + _current_postion_info()
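Note: the widened guard above keeps the mask check working across PyTorch releases, since comparison ops return torch.bool masks (instead of uint8) from PyTorch 1.2 onward, while older builds have no torch.bool attribute at all. A minimal standalone sketch of the same compatibility check (the helper name _mask_dtype_ok is illustrative, not part of this patch):

    import torch

    def _mask_dtype_ok(mask):
        # uint8 masks come from older PyTorch; bool masks from PyTorch >= 1.2.
        if mask.dtype == torch.uint8:
            return True
        # Guard the attribute lookup so this also runs on builds without torch.bool.
        return hasattr(torch, 'bool') and mask.dtype == torch.bool

    pad_mask = torch.tensor([1, 1, 0]).eq(0)  # dtype depends on the installed torch version
    assert _mask_dtype_ok(pad_mask)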
@@ -151,7 +151,7 @@ class TestTranslation(unittest.TestCase):
                     '--decoder-layers', '2',
                     '--encoder-embed-dim', '8',
                     '--decoder-embed-dim', '8',
-                ])
+                ], run_validation=True)
                 generate_main(data_dir)
 
     def test_lightconv(self):
@@ -257,7 +257,9 @@ class TestLanguageModeling(unittest.TestCase):
             with tempfile.TemporaryDirectory('test_transformer_lm') as data_dir:
                 create_dummy_data(data_dir)
                 preprocess_lm_data(data_dir)
-                train_language_model(data_dir, 'transformer_lm', ['--add-bos-token'])
+                train_language_model(
+                    data_dir, 'transformer_lm', ['--add-bos-token'], run_validation=True,
+                )
                 eval_lm_main(data_dir)
@@ -457,7 +459,7 @@ def preprocess_translation_data(data_dir, extra_flags=None):
     preprocess.main(preprocess_args)
 
 
-def train_translation_model(data_dir, arch, extra_flags=None, task='translation'):
+def train_translation_model(data_dir, arch, extra_flags=None, task='translation', run_validation=False):
     train_parser = options.get_training_parser()
     train_args = options.parse_args_and_arch(
         train_parser,
@@ -477,6 +479,7 @@ def train_translation_model(data_dir, arch, extra_flags=None, task='translation'):
     )
     train.main(train_args)
 
+    if run_validation:
         # test validation
         validate_parser = options.get_validation_parser()
         validate_args = options.parse_args_and_arch(
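The hunk above is truncated; a hedged sketch of the full gated validation step it introduces (the flag list and the validate entry point are assumptions based on the visible parser calls, not copied from the patch, and the snippet relies on the surrounding module's imports of os, options, train and validate):

    if run_validation:
        # test validation
        validate_parser = options.get_validation_parser()
        validate_args = options.parse_args_and_arch(
            validate_parser,
            [
                '--task', task,
                data_dir,
                '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
                '--valid-subset', 'valid',
                '--max-tokens', '500',
            ],
        )
        validate.main(validate_args)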
@@ -534,7 +537,7 @@ def preprocess_lm_data(data_dir):
     preprocess.main(preprocess_args)
 
 
-def train_language_model(data_dir, arch, extra_flags=None):
+def train_language_model(data_dir, arch, extra_flags=None, run_validation=False):
     train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
         train_parser,
@@ -557,6 +560,7 @@ def train_language_model(data_dir, arch, extra_flags=None):
     )
     train.main(train_args)
 
+    if run_validation:
         # test validation
         validate_parser = options.get_validation_parser()
         validate_args = options.parse_args_and_arch(
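For reference, a test opts in to the new validation pass through the keyword argument added above; a short usage sketch mirroring the TestTranslation and TestLanguageModeling hunks (the architecture names and flags here are illustrative):

    train_translation_model(data_dir, 'fconv_iwslt_de_en', [
        '--encoder-embed-dim', '8',
        '--decoder-embed-dim', '8',
    ], run_validation=True)

    train_language_model(
        data_dir, 'transformer_lm', ['--add-bos-token'], run_validation=True,
    )

Existing callers are unaffected because run_validation defaults to False.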