Commit 2a681d99 authored by Alexei Baevski's avatar Alexei Baevski Committed by Myle Ott
Browse files

make sure tensor used to index is cuda if on gpu

parent 88df72c0
...@@ -380,7 +380,7 @@ class SequenceGenerator(object): ...@@ -380,7 +380,7 @@ class SequenceGenerator(object):
new_bsz = bsz - len(finalized_sents) new_bsz = bsz - len(finalized_sents)
batch_mask = torch.ones(bsz).type_as(cand_indices) batch_mask = torch.ones(bsz).type_as(cand_indices)
batch_mask[torch.LongTensor(finalized_sents)] = 0 batch_mask[cand_indices.new(finalized_sents)] = 0
batch_idxs = batch_mask.nonzero().squeeze(-1) batch_idxs = batch_mask.nonzero().squeeze(-1)
eos_mask = eos_mask[batch_idxs] eos_mask = eos_mask[batch_idxs]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment