"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d8e3bdbb4cce939e8f95e0f1fa33bdd7350f4b79"
Unverified Commit 2d506ea4 authored by anruijian's avatar anruijian Committed by GitHub
Browse files

Fix tf random token masking probability in data collator (#21834)

* fix tf random mask tokens probability

* fix tf random mask tokens probability in collator for langauge modelling
parent 4fe744f5
...@@ -679,7 +679,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin): ...@@ -679,7 +679,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
inputs = tf.where(indices_replaced, mask_token_id, inputs) inputs = tf.where(indices_replaced, mask_token_id, inputs)
# 10% of the time, we replace masked input tokens with random word # 10% of the time, we replace masked input tokens with random word
indices_random = self.tf_bernoulli(input_shape, 0.1) & masked_indices & ~indices_replaced indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=tf.int64) random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=tf.int64)
inputs = tf.where(indices_random, random_words, inputs) inputs = tf.where(indices_random, random_words, inputs)
...@@ -1062,7 +1062,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): ...@@ -1062,7 +1062,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs) inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs)
# 10% of the time, we replace masked input tokens with random word # 10% of the time, we replace masked input tokens with random word
indices_random = self.tf_bernoulli(input_shape, 0.1) & masked_indices & ~indices_replaced indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64) random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
inputs = tf.where(indices_random, random_words, inputs) inputs = tf.where(indices_random, random_words, inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment