Commit 1e2acd0d authored by Shashank Gupta, committed by GitHub

Bug fix for permutation language modelling (#8409)

parent bf8625e7
@@ -579,7 +579,7 @@ class DataCollatorForPermutationLanguageModeling:
 masked_indices.masked_fill_(padding_mask, value=0.0)
 # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
-non_func_mask = ~(padding_mask & special_tokens_mask)
+non_func_mask = ~(padding_mask | special_tokens_mask)
 inputs[masked_indices] = self.tokenizer.mask_token_id
 labels[~masked_indices] = -100 # We only compute loss on masked tokens
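
The effect of changing `&` to `|` can be checked on a toy sequence. The sketch below is not the collator's code: it mirrors the two element-wise mask expressions with plain Python booleans so that it runs without PyTorch. The token list and both masks are hypothetical, and it assumes padding positions and special-token positions do not overlap.

```python
# Hypothetical toy sequence; the masks are illustrative, not tokenizer output.
tokens = ["[CLS]", "the", "cat", "[SEP]", "[PAD]", "[PAD]"]
special_tokens_mask = [True, False, False, True, False, False]
padding_mask = [False, False, False, False, True, True]

# Old expression, ~(padding_mask & special_tokens_mask), element by element:
# a position is excluded only if it is BOTH padding and special.
old_non_func_mask = [
    not (p and s) for p, s in zip(padding_mask, special_tokens_mask)
]

# New expression, ~(padding_mask | special_tokens_mask), element by element:
# a position is excluded if it is padding OR special.
new_non_func_mask = [
    not (p or s) for p, s in zip(padding_mask, special_tokens_mask)
]

for tok, old, new in zip(tokens, old_non_func_mask, new_non_func_mask):
    print(f"{tok:>6}  old={old!s:<5}  new={new}")
```

Under these assumptions the old expression marks every position as non-functional, because no position is both padding and special, while the new expression marks only `the` and `cat`.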