Commit c38b1f91 authored by Taylan Bilal's avatar Taylan Bilal Committed by Facebook Github Bot
Browse files

removing tensor resizing in future_mask (#877)

Summary:
tensor resizing doesn't work well with tpus, this change is equivalent
to the base and works better w/ tpus.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/877

Differential Revision: D16241620

Pulled By: myleott

fbshipit-source-id: 402c7d5eb6175a66a0420d10e74eb0a9e085790e
parent 397ba265
...@@ -462,10 +462,8 @@ class TransformerDecoder(FairseqIncrementalDecoder): ...@@ -462,10 +462,8 @@ class TransformerDecoder(FairseqIncrementalDecoder):
def buffered_future_mask(self, tensor): def buffered_future_mask(self, tensor):
dim = tensor.size(0) dim = tensor.size(0)
if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device: if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device or self._future_mask.size(0) < dim:
self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
if self._future_mask.size(0) < dim:
self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
return self._future_mask[:dim, :dim] return self._future_mask[:dim, :dim]
def upgrade_state_dict_named(self, state_dict, name): def upgrade_state_dict_named(self, state_dict, name):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment