Commit 22e535e2 authored by alexeib's avatar alexeib Committed by Facebook Github Bot
Browse files

Merge internal changes

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/296

Differential Revision: D10121830

Pulled By: alexeib

fbshipit-source-id: 1b73430bdfdcb20a9a6123abfca3472a0d307b3b
parent b87c5366
...@@ -45,10 +45,10 @@ def main(parsed_args): ...@@ -45,10 +45,10 @@ def main(parsed_args):
print('| loading model(s) from {}'.format(parsed_args.path)) print('| loading model(s) from {}'.format(parsed_args.path))
models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task) models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task)
args.__dict__.update(parsed_args.__dict__) for arg in vars(parsed_args).keys():
print(args) if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
setattr(args, arg, getattr(parsed_args, arg))
task.args = args task = tasks.setup_task(args)
# Load dataset splits # Load dataset splits
task.load_dataset(args.gen_subset) task.load_dataset(args.gen_subset)
......
...@@ -71,7 +71,7 @@ class SequenceScorer(object): ...@@ -71,7 +71,7 @@ class SequenceScorer(object):
avg_probs = probs avg_probs = probs
else: else:
avg_probs.add_(probs) avg_probs.add_(probs)
if attn is not None: if attn is not None and torch.is_tensor(attn):
attn = attn.data attn = attn.data
if avg_attn is None: if avg_attn is None:
avg_attn = attn avg_attn = attn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment