Commit e4f51e18 authored by alexeib's avatar alexeib Committed by Myle Ott
Browse files

load args from model for eval_lm

parent 45082e48
......@@ -14,23 +14,28 @@ from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
def main(args):
assert args.path is not None, '--path required for evaluation!'
def main(parsed_args):
assert parsed_args.path is not None, '--path required for evaluation!'
args.tokens_per_sample = getattr(args, 'tokens_per_sample', 1024)
print(parsed_args)
use_cuda = torch.cuda.is_available() and not parsed_args.cpu
task = tasks.setup_task(parsed_args)
# Load ensemble
print('| loading model(s) from {}'.format(parsed_args.path))
models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task)
args.__dict__.update(parsed_args.__dict__)
print(args)
use_cuda = torch.cuda.is_available() and not args.cpu
task.args = args
# Load dataset splits
task = tasks.setup_task(args)
task.load_dataset(args.gen_subset)
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))
# Load ensemble
print('| loading model(s) from {}'.format(args.path))
models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task)
# Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
for model in models:
model.make_generation_fast_()
......
......@@ -193,6 +193,8 @@ class TransformerLanguageModel(FairseqLanguageModel):
else:
embed_tokens = Embedding(len(task.dictionary), args.decoder_embed_dim, task.dictionary.pad())
print(args)
decoder = TransformerDecoder(args, task.dictionary, embed_tokens, no_encoder_attn=True)
return TransformerLanguageModel(decoder)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment