Commit 9dd87245 authored by linkerr's avatar linkerr Committed by Facebook Github Bot
Browse files

fixed torch 0.4.0 , "RuntimeError: Expected object of type torch.cuda… (#393)

Summary:
….LongTensor but found type torch.cuda.FloatTensor for argument #3 'index' " error

in the torch.__version__ == 0.4.0 ,
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
will return a float dtype Tensor, when exec the "line 321: fairseq/fairseq/models/fconv.py " will throw a RuntimeError
Pull Request resolved: https://github.com/pytorch/fairseq/pull/393

Differential Revision: D13276496

Pulled By: myleott

fbshipit-source-id: e7986246fbe2c79fff61bcab0e5bec9dd63e0afd
parent 7bbe528d
......@@ -139,7 +139,7 @@ class SequenceGenerator(object):
# compute the encoder output for each beam
encoder_out = model.encoder(**encoder_input)
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
new_order = new_order.to(src_tokens.device)
new_order = new_order.to(src_tokens.device).long()
encoder_out = model.encoder.reorder_encoder_out(encoder_out, new_order)
encoder_outs.append(encoder_out)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment