Commit 1ebfeb79 authored by Lysandre

Cast to long when masking tokens

parent 9c67196b
@@ -195,6 +195,7 @@ def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -
 def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]:
     """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
+    inputs = inputs.clone().type(dtype=torch.long)
     labels = inputs.clone()
     # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
     probability_matrix = torch.full(labels.shape, args.mlm_probability)
...
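The added cast ensures that inputs holds integer (torch.long) token ids before positions are overwritten with [MASK] or random ids and passed to an embedding layer, which requires LongTensor indices. Below is a minimal, self-contained sketch of the 80%/10%/10% masking scheme the docstring describes, using hypothetical values (mask_token_id = 103, vocab_size = 30522, mlm_probability = 0.15, ignore index -1) rather than the script's tokenizer and args; it is an illustration of the technique, not the script's exact code.

import torch

mlm_probability = 0.15   # assumed default, as in Bert/RoBERTa
mask_token_id = 103      # hypothetical [MASK] id
vocab_size = 30522       # hypothetical vocabulary size

inputs = torch.randint(vocab_size, (2, 8))      # fake batch of token ids
inputs = inputs.clone().type(dtype=torch.long)  # the cast added by this commit
labels = inputs.clone()

# Sample positions to mask with probability mlm_probability.
probability_matrix = torch.full(labels.shape, mlm_probability)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -1  # unmasked positions are ignored by the loss (assumed ignore index)

# 80% of masked positions are replaced with the [MASK] token id.
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
inputs[indices_replaced] = mask_token_id

# 10% of masked positions (half of the remaining 20%) get a random token id.
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
inputs[indices_random] = torch.randint(vocab_size, labels.shape, dtype=torch.long)[indices_random]

# The remaining 10% keep their original token ids; (inputs, labels) is the MLM training pair.

Because the replacement values (mask_token_id, random ids) are long integers, keeping inputs as torch.long avoids dtype mismatches here and at the embedding lookup, which is the point of the cast.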