Commit 92a6c548 authored by Kartikay Khandelwal's avatar Kartikay Khandelwal Committed by Facebook Github Bot
Browse files

Refactor BERTDataset to the more general MaskedLMDataset

Summary: The current BERTDataset has a lot of components needed for generic MaskedLM training but is too restrictive in terms of the assumptions it makes - two blocks being masked, the special tokens used for the sentence embedding as well as the separator etc. In this diff I refactor this dataset and at the same time add make some of the parameters including the probabilities associated with masking configurable.

Reviewed By: rutyrinott

Differential Revision: D14222467

fbshipit-source-id: e9f78788dfe7f56646ba09c62967c4c0bd30aed8
parent 4d59517f
...@@ -7,50 +7,9 @@ ...@@ -7,50 +7,9 @@
import contextlib import contextlib
import os import os
import math
import numpy as np import numpy as np
def mask(sentence, mask_id, pad_id, dictionary_token_range, mask_ratio=.15):
"""mask tokens for masked language model training
Samples mask_ratio tokens that will be predicted by LM.
- 80%: Replace the word with mask_id
- 10%: replate the word with random token within dictionary_token_range
- 10%: keeps word unchanged
This function may not be efficient enough since we had multiple conversions
between np and torch, we can replace them with torch operators later
Args:
sentence: 1d tensor to be masked
mask_id: index to use for masking the sentence
pad_id: index to use for masking the target for tokens we aren't predicting
dictionary_token_range: range of indices in dictionary which can be used
for random word replacement (e.g. without special characters)
mask_ratio: ratio of tokens to be masked in the sentence
Return:
masked_sent: masked sentence
target: target with words which we are not predicting replaced by pad_id
"""
masked_sent = np.copy(sentence)
sent_length = len(sentence)
mask_num = math.ceil(sent_length * mask_ratio)
mask = np.random.choice(sent_length, mask_num)
target = np.copy(sentence)
for i in range(sent_length):
if i in mask:
rand = np.random.random()
if rand < 0.8:
masked_sent[i] = mask_id
elif rand < 0.9:
# sample random token
masked_sent[i] = (
np.random.randint(dictionary_token_range[0], dictionary_token_range[1])
)
else:
target[i] = pad_id
return masked_sent, target
def infer_language_pair(path): def infer_language_pair(path):
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx""" """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
src, dst = None, None src, dst = None, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment