Commit 20e7836e authored by Naman Goyal's avatar Naman Goyal Committed by Facebook Github Bot
Browse files

fixed arg passing in masked_lm_dataset

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/715

Differential Revision: D15240723

fbshipit-source-id: 11d7280cb187d68f107902822e878f2a04b840c7
parent e37bd948
...@@ -54,8 +54,6 @@ class MaskedLMDataset(FairseqDataset): ...@@ -54,8 +54,6 @@ class MaskedLMDataset(FairseqDataset):
replaced with the "MASK" token. replaced with the "MASK" token.
random_token_prob: specifies the probability of a given token being random_token_prob: specifies the probability of a given token being
replaced by a random token from the vocabulary. replaced by a random token from the vocabulary.
unchanged_prob: specifies the probability of keeping a given
token unchanged.
""" """
def __init__( def __init__(
...@@ -129,9 +127,6 @@ class MaskedLMDataset(FairseqDataset): ...@@ -129,9 +127,6 @@ class MaskedLMDataset(FairseqDataset):
mask_idx: int, mask_idx: int,
pad_idx: int, pad_idx: int,
dictionary_token_range: Tuple, dictionary_token_range: Tuple,
masking_ratio: float = 0.15,
masking_prob: float = 0.8,
random_token_prob: float = 0.1
): ):
""" """
Mask tokens for Masked Language Model training Mask tokens for Masked Language Model training
...@@ -149,12 +144,6 @@ class MaskedLMDataset(FairseqDataset): ...@@ -149,12 +144,6 @@ class MaskedLMDataset(FairseqDataset):
dictionary_token_range: range of indices in dictionary which can dictionary_token_range: range of indices in dictionary which can
be used for random word replacement be used for random word replacement
(e.g. without special characters) (e.g. without special characters)
masking_ratio: specifies what percentage of the blocks should be
masked.
masking_prob: specifies the probability of a given token being
replaced with the "MASK" token.
random_token_prob: specifies the probability of a given token being
replaced by a random token from the vocabulary
Return: Return:
masked_sent: masked sentence masked_sent: masked sentence
target: target with words which we are not predicting replaced target: target with words which we are not predicting replaced
...@@ -162,7 +151,7 @@ class MaskedLMDataset(FairseqDataset): ...@@ -162,7 +151,7 @@ class MaskedLMDataset(FairseqDataset):
""" """
masked_sent = np.copy(sentence) masked_sent = np.copy(sentence)
sent_length = len(sentence) sent_length = len(sentence)
mask_num = math.ceil(sent_length * masking_ratio) mask_num = math.ceil(sent_length * self.masking_ratio)
mask = np.random.choice(sent_length, mask_num) mask = np.random.choice(sent_length, mask_num)
target = np.copy(sentence) target = np.copy(sentence)
...@@ -172,12 +161,12 @@ class MaskedLMDataset(FairseqDataset): ...@@ -172,12 +161,12 @@ class MaskedLMDataset(FairseqDataset):
# replace with mask if probability is less than masking_prob # replace with mask if probability is less than masking_prob
# (Eg: 0.8) # (Eg: 0.8)
if rand < masking_prob: if rand < self.masking_prob:
masked_sent[i] = mask_idx masked_sent[i] = mask_idx
# replace with random token if probability is less than # replace with random token if probability is less than
# masking_prob + random_token_prob (Eg: 0.9) # masking_prob + random_token_prob (Eg: 0.9)
elif rand < (masking_prob + random_token_prob): elif rand < (self.masking_ratio + self.random_token_prob):
# sample random token from dictionary # sample random token from dictionary
masked_sent[i] = ( masked_sent[i] = (
np.random.randint( np.random.randint(
...@@ -229,8 +218,6 @@ class MaskedLMDataset(FairseqDataset): ...@@ -229,8 +218,6 @@ class MaskedLMDataset(FairseqDataset):
# mask according to specified probabilities. # mask according to specified probabilities.
masked_blk_one, masked_tgt_one = self._mask_block( masked_blk_one, masked_tgt_one = self._mask_block(
s["block_one"], self.mask_idx, self.pad_idx, token_range, s["block_one"], self.mask_idx, self.pad_idx, token_range,
masking_ratio=self.masking_ratio, masking_prob=self.masking_prob,
random_token_prob=self.random_token_prob,
) )
tokens = np.concatenate([ tokens = np.concatenate([
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment