Unverified Commit 625318f5 authored by Funtowicz Morgan's avatar Funtowicz Morgan Committed by GitHub
Browse files

tensor.nonzero() is deprecated in PyTorch 1.6 (#6715)


Signed-off-by: default avatarMorgan Funtowicz <funtowiczmo@gmail.com>
parent 124c3d6a
......@@ -1226,7 +1226,7 @@ class FillMaskPipeline(Pipeline):
values = tf.gather_nd(values, tf.reshape(sort_inds, (-1, 1))).numpy()
predictions = target_inds[sort_inds.numpy()]
else:
masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero()
masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
# Fill mask pipeline supports only one ${mask_token} per sample
self.ensure_exactly_one_mask_token(masked_index.numpy())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment