"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e93ccb3290ec4fb0076495c86af9de33f27048bd"
Commit 5ed50a93 authored by LysandreJik

LM finetuning won't mask special tokens anymore

parent cc412edd
@@ -108,7 +108,12 @@ def mask_tokens(inputs, tokenizer, args):
     """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
     labels = inputs.clone()
     # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
-    masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).bool()
+    probability_matrix = torch.full(labels.shape, args.mlm_probability)
+    probability_matrix *= torch.tensor(
+        [tokenizer.get_sequence_ids(val, special_tokens_present=True) for val in labels.tolist()],
+        dtype=torch.float
+    )
+    masked_indices = torch.bernoulli(probability_matrix).bool()
     labels[~masked_indices] = -1  # We only compute loss on masked tokens
     # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
...
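A rough illustration of the change (not part of the commit): zeroing the sampling probability at special-token positions means torch.bernoulli can never select them, so tokens like [CLS] and [SEP] are never masked and never contribute to the MLM loss. The token ids and the hand-built sequence_mask below are illustrative stand-ins for the tokenizer.get_sequence_ids(...) call in the diff.

```python
import torch

# Illustrative values; in the script these come from the dataloader and args.
mlm_probability = 0.15
inputs = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])  # e.g. [CLS] ... wordpieces ... [SEP]
labels = inputs.clone()

# 1.0 for ordinary tokens, 0.0 for special tokens; a stand-in for the
# per-sequence mask returned by the tokenizer call in the diff above.
sequence_mask = torch.tensor([[0., 1., 1., 1., 1., 0.]])

probability_matrix = torch.full(labels.shape, mlm_probability)
probability_matrix *= sequence_mask  # special-token positions now have probability 0

# torch.bernoulli never draws a 1 where the probability is 0,
# so the special-token positions can never end up in masked_indices.
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -1  # loss is only computed on masked tokens
```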