Commit 5218a7c9 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Fix compatibility with PyTorch 1.0.x (Fixes #906)

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/910

Differential Revision: D16536532

Pulled By: myleott

fbshipit-source-id: 56bb5570e70b5670ad87c64d9dd20c64c1fa9f5c
parent 40f16872
......@@ -5,7 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from torch.utils.data._utils.collate import default_collate
from torch.utils.data.dataloader import default_collate
from . import FairseqDataset
......
......@@ -129,7 +129,7 @@ class MaskTokensDataset(BaseWrapperDataset):
if self.mask_whole_words is not None:
mask = np.repeat(mask, word_lens)
new_item = np.full(len(mask), self.pad_idx)
new_item[mask] = item[torch.from_numpy(mask)]
new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8))]
return torch.from_numpy(new_item)
# decide unmasking and random replacement
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment