Commit 8f058ea0 authored by Sergey Edunov

Don't generate during training, add --quiet to generate.py

parent a8260d52
@@ -17,7 +17,7 @@ from fairseq.dictionary import Dictionary
 from fairseq.indexed_dataset import IndexedDataset, IndexedInMemoryDataset
-def load_with_check(path, src=None, dst=None):
+def load_with_check(path, load_splits, src=None, dst=None):
     """Loads the train, valid, and test sets from the specified folder
     and check that training files exist."""
@@ -43,12 +43,12 @@ def load_with_check(path, src=None, dst=None):
     else:
         raise ValueError('training file not found for {}-{}'.format(src, dst))
-    dataset = load(path, src, dst)
+    dataset = load(path, load_splits, src, dst)
     return dataset
-def load(path, src, dst):
-    """Loads the train, valid, and test sets from the specified folder."""
+def load(path, load_splits, src, dst):
+    """Loads specified data splits (e.g. test, train or valid) from the path."""
     langcode = '{}-{}'.format(src, dst)
@@ -59,7 +59,7 @@ def load(path, src, dst):
     dst_dict = Dictionary.load(fmt_path('dict.{}.txt', dst))
     dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
-    for split in ['train', 'valid', 'test']:
+    for split in load_splits:
         for k in itertools.count():
             prefix = "{}{}".format(split, k if k > 0 else '')
             src_path = fmt_path('{}.{}.{}', prefix, langcode, src)
...
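For context, a minimal sketch of how the reworked loaders are now called, assuming the module is imported as fairseq.data; the data directory and language codes are illustrative placeholders:

from fairseq import data

# Training loads only the splits it iterates over ...
train_data = data.load_with_check('data-bin/iwslt14.de-en', ['train', 'valid'], src='de', dst='en')
# ... while generation loads just the split it will decode.
test_data = data.load('data-bin/iwslt14.de-en', ['test'], src='de', dst='en')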
@@ -104,6 +104,8 @@ def add_generation_args(parser):
                        help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
     group.add_argument('--unk-replace-dict', default='', type=str,
                        help='performs unk word replacement')
+    group.add_argument('--quiet', action='store_true',
+                       help='Only print final scores')
     return group
...
@@ -94,7 +94,7 @@ def load_checkpoint(filename, model, optimizer, lr_scheduler, cuda_device=None):
     return epoch, batch_offset
-def load_ensemble_for_inference(filenames, data_path):
+def load_ensemble_for_inference(filenames, data_path, split):
     # load model architectures and weights
     states = []
     for filename in filenames:
@@ -106,7 +106,7 @@ def load_ensemble_for_inference(filenames, data_path):
     # load dataset
     args = states[0]['args']
-    dataset = data.load(data_path, args.source_lang, args.target_lang)
+    dataset = data.load(data_path, [split], args.source_lang, args.target_lang)
     # build models
     ensemble = []
...
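The ensemble loader now also takes the name of the split to load. A hedged sketch of a call site, with the checkpoint path and data directory as placeholder values:

from fairseq import utils

# Load an ensemble of one checkpoint plus only the 'test' split of the dataset.
models, dataset = utils.load_ensemble_for_inference(
    ['checkpoints/checkpoint_best.pt'],  # filenames: list of model checkpoints (placeholder path)
    'data-bin/iwslt14.de-en',            # data_path (placeholder)
    'test',                              # split, passed through to data.load as [split]
)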
@@ -38,7 +38,7 @@ def main():
     # Load model and dataset
     print('| loading model(s) from {}'.format(', '.join(args.path)))
-    models, dataset = utils.load_ensemble_for_inference(args.path, args.data)
+    models, dataset = utils.load_ensemble_for_inference(args.path, args.data, args.gen_subset)
     print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
     print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
@@ -81,6 +81,8 @@ def main():
     bpe_symbol = '@@ ' if args.remove_bpe else None
     def display_hypotheses(id, src, orig, ref, hypos):
+        if args.quiet:
+            return
         id_str = '' if id is None else '-{}'.format(id)
         src_str = to_sentence(dataset.src_dict, src, bpe_symbol)
         print('S{}\t{}'.format(id_str, src_str))
...
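With the new flag, generate.py suppresses the per-sentence output (source, reference, hypotheses) and keeps only the summary lines, e.g. by appending --quiet to the usual generate.py command line. A minimal, self-contained sketch of the option in isolation; the real parser is built in fairseq's options module:

import argparse

# Rebuild just the added option to show its parsed form; everything else is omitted.
parser = argparse.ArgumentParser()
parser.add_argument('--quiet', action='store_true', help='Only print final scores')
args = parser.parse_args(['--quiet'])
assert args.quiet  # display_hypotheses() returns early when this is set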
@@ -29,9 +29,6 @@ def main():
     dataset_args.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                               help='comma separated list ofdata subsets '
                                    ' to use for validation (train, valid, valid1,test, test1)')
-    dataset_args.add_argument('--test-subset', default='test', metavar='SPLIT',
-                              help='comma separated list ofdata subset '
-                                   'to use for testing (train, valid, test)')
     options.add_optimization_args(parser)
     options.add_checkpoint_args(parser)
     options.add_model_args(parser)
@@ -48,7 +45,7 @@ def main():
     torch.manual_seed(args.seed)
     # Load dataset
-    dataset = data.load_with_check(args.data, args.source_lang, args.target_lang)
+    dataset = data.load_with_check(args.data, ['train', 'valid'], args.source_lang, args.target_lang)
     if args.source_lang is None or args.target_lang is None:
         # record inferred languages in args, so that it's saved in checkpoints
         args.source_lang, args.target_lang = dataset.src, dataset.dst
@@ -100,13 +97,6 @@ def main():
     train_meter.stop()
     print('| done training in {:.1f} seconds'.format(train_meter.sum))
-    # Generate on test set and compute BLEU score
-    for beam in [1, 5, 10, 20]:
-        for subset in args.test_subset.split(','):
-            scorer = score_test(args, trainer.get_model(), dataset, subset, beam,
-                                cuda_device=(0 if num_gpus > 0 else None))
-            print('| Test on {} with beam={}: {}'.format(subset, beam, scorer.result_string()))
     # Stop multiprocessing
     trainer.stop()
@@ -192,19 +182,5 @@ def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
     return val_loss
-def score_test(args, model, dataset, subset, beam, cuda_device):
-    """Evaluate the model on the test set and return the BLEU scorer."""
-    translator = SequenceGenerator([model], dataset.dst_dict, beam_size=beam)
-    if torch.cuda.is_available():
-        translator.cuda()
-    scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
-    itr = dataset.dataloader(subset, batch_size=4, max_positions=args.max_positions)
-    for _, _, ref, hypos in translator.generate_batched_itr(itr, cuda_device=cuda_device):
-        scorer.add(ref.int().cpu(), hypos[0]['tokens'].int().cpu())
-    return scorer
 if __name__ == '__main__':
     main()
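With score_test removed, train.py no longer runs beam search after training; BLEU evaluation is instead done separately with generate.py on whichever subset is of interest. A hedged sketch of reproducing the old sweep over beam sizes as standalone invocations, where the data directory, checkpoint path, and flag spellings other than --quiet are assumptions:

import subprocess

# Sweep the same beam sizes the removed code used, one generate.py run per beam.
for beam in [1, 5, 10, 20]:
    subprocess.run(
        ['python', 'generate.py', 'data-bin/iwslt14.de-en',
         '--path', 'checkpoints/checkpoint_best.pt',
         '--gen-subset', 'test',
         '--beam', str(beam),
         '--quiet'],
        check=True,
    )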