"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "18ebd57bd80fdd1bb8cbf9af075ba0301705bc6f"
Commit b35d9bca authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Fix rearranging of encoder_out in SequenceGenerator

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/625

Differential Revision: D15595787

Pulled By: myleott

fbshipit-source-id: ba6edf305ed41be392194f492e034dd66d1743fe
parent 6a21b232
...@@ -304,7 +304,7 @@ class SequenceGenerator(object): ...@@ -304,7 +304,7 @@ class SequenceGenerator(object):
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs) corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size) reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
model.reorder_incremental_state(reorder_state) model.reorder_incremental_state(reorder_state)
model.reorder_encoder_out(encoder_outs, reorder_state) encoder_outs = model.reorder_encoder_out(encoder_outs, reorder_state)
lprobs, avg_attn_scores = model.forward_decoder( lprobs, avg_attn_scores = model.forward_decoder(
tokens[:, :step + 1], encoder_outs, temperature=self.temperature, tokens[:, :step + 1], encoder_outs, temperature=self.temperature,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment