"test/vscode:/vscode.git/clone" did not exist on "fc0e3b91744bef277cd8e4f68736f2970d0629d8"
Unverified Commit 2d506ea4 authored by anruijian's avatar anruijian Committed by GitHub
Browse files

Fix tf random token masking probability in data collator (#21834)

* fix tf random mask tokens probability

* fix tf random mask tokens probability in collator for langauge modelling
parent 4fe744f5
......@@ -679,7 +679,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
inputs = tf.where(indices_replaced, mask_token_id, inputs)
# 10% of the time, we replace masked input tokens with random word
indices_random = self.tf_bernoulli(input_shape, 0.1) & masked_indices & ~indices_replaced
indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=tf.int64)
inputs = tf.where(indices_random, random_words, inputs)
......@@ -1062,7 +1062,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs)
# 10% of the time, we replace masked input tokens with random word
indices_random = self.tf_bernoulli(input_shape, 0.1) & masked_indices & ~indices_replaced
indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
inputs = tf.where(indices_random, random_words, inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment