Commit 20e7836e authored by Naman Goyal, committed by Facebook Github Bot
Browse files

fixed arg passing in masked_lm_dataset

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/715

Differential Revision: D15240723

fbshipit-source-id: 11d7280cb187d68f107902822e878f2a04b840c7
parent e37bd948
......@@ -54,8 +54,6 @@ class MaskedLMDataset(FairseqDataset):
replaced with the "MASK" token.
random_token_prob: specifies the probability of a given token being
replaced by a random token from the vocabulary.
unchanged_prob: specifies the probability of keeping a given
token unchanged.
"""
def __init__(
......@@ -129,9 +127,6 @@ class MaskedLMDataset(FairseqDataset):
mask_idx: int,
pad_idx: int,
dictionary_token_range: Tuple,
masking_ratio: float = 0.15,
masking_prob: float = 0.8,
random_token_prob: float = 0.1
):
"""
Mask tokens for Masked Language Model training
......@@ -149,12 +144,6 @@ class MaskedLMDataset(FairseqDataset):
dictionary_token_range: range of indices in dictionary which can
be used for random word replacement
(e.g. without special characters)
masking_ratio: specifies what percentage of the blocks should be
masked.
masking_prob: specifies the probability of a given token being
replaced with the "MASK" token.
random_token_prob: specifies the probability of a given token being
replaced by a random token from the vocabulary
Return:
masked_sent: masked sentence
target: target with words which we are not predicting replaced
......@@ -162,7 +151,7 @@ class MaskedLMDataset(FairseqDataset):
"""
masked_sent = np.copy(sentence)
sent_length = len(sentence)
mask_num = math.ceil(sent_length * masking_ratio)
mask_num = math.ceil(sent_length * self.masking_ratio)
mask = np.random.choice(sent_length, mask_num)
target = np.copy(sentence)
......@@ -172,12 +161,12 @@ class MaskedLMDataset(FairseqDataset):
# replace with mask if probability is less than masking_prob
# (Eg: 0.8)
if rand < masking_prob:
if rand < self.masking_prob:
masked_sent[i] = mask_idx
# replace with random token if probability is less than
# masking_prob + random_token_prob (Eg: 0.9)
elif rand < (masking_prob + random_token_prob):
elif rand < (self.masking_ratio + self.random_token_prob):
# sample random token from dictionary
masked_sent[i] = (
np.random.randint(
......@@ -229,8 +218,6 @@ class MaskedLMDataset(FairseqDataset):
# mask according to specified probabilities.
masked_blk_one, masked_tgt_one = self._mask_block(
s["block_one"], self.mask_idx, self.pad_idx, token_range,
masking_ratio=self.masking_ratio, masking_prob=self.masking_prob,
random_token_prob=self.random_token_prob,
)
tokens = np.concatenate([
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment