Commit 831b6b6e authored by Naman Goyal's avatar Naman Goyal Committed by Facebook Github Bot
Browse files

Bart fix prev tokens collate

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/920

Differential Revision: D18593088

fbshipit-source-id: d4479ee8dae2ca623e62e12bd145165a116fb70a
parent 534eaa2c
......@@ -139,7 +139,12 @@ class BARTHubInterface(nn.Module):
))
tokens.to(device=self.device),
prev_output_tokens = tokens.clone()
prev_output_tokens[:, 0] = tokens[:, -1]
prev_output_tokens[:, 0] = tokens.gather(
1,
(tokens.ne(self.task.source_dictionary.pad()).sum(dim=1)- 1).unsqueeze(-1),
).squeeze()
prev_output_tokens[:, 1:] = tokens[:, :-1]
features, extra = self.model(
src_tokens=tokens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment