"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "a303f07819e4c436172ea49647332cf60d5c4ec4"
Commit baa8ce11 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Updates for PyTorch 1.2 masking/bool behavior

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/821

Differential Revision: D16790120

Pulled By: myleott

fbshipit-source-id: 2fb5070172636561d08596a29f08c93df07548bf
parent d015d23a
...@@ -127,7 +127,7 @@ class MaskTokensDataset(BaseWrapperDataset): ...@@ -127,7 +127,7 @@ class MaskTokensDataset(BaseWrapperDataset):
if self.mask_whole_words is not None: if self.mask_whole_words is not None:
mask = np.repeat(mask, word_lens) mask = np.repeat(mask, word_lens)
new_item = np.full(len(mask), self.pad_idx) new_item = np.full(len(mask), self.pad_idx)
new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8))] new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
return torch.from_numpy(new_item) return torch.from_numpy(new_item)
# decide unmasking and random replacement # decide unmasking and random replacement
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment