Commit 22e535e2 authored by alexeib's avatar alexeib Committed by Facebook Github Bot
Browse files

Merge internal changes

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/296

Differential Revision: D10121830

Pulled By: alexeib

fbshipit-source-id: 1b73430bdfdcb20a9a6123abfca3472a0d307b3b
parent b87c5366
......@@ -45,10 +45,10 @@ def main(parsed_args):
print('| loading model(s) from {}'.format(parsed_args.path))
models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task)
args.__dict__.update(parsed_args.__dict__)
print(args)
task.args = args
for arg in vars(parsed_args).keys():
if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
setattr(args, arg, getattr(parsed_args, arg))
task = tasks.setup_task(args)
# Load dataset splits
task.load_dataset(args.gen_subset)
......
......@@ -71,7 +71,7 @@ class SequenceScorer(object):
avg_probs = probs
else:
avg_probs.add_(probs)
if attn is not None:
if attn is not None and torch.is_tensor(attn):
attn = attn.data
if avg_attn is None:
avg_attn = attn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment