Commit 08e866f9 authored by Ruty Rinott's avatar Ruty Rinott Committed by Facebook Github Bot
Browse files

moving masking logic to collate

Summary: Move masking logic to data_utils

Reviewed By: kartikayk, jingfeidu

Differential Revision: D14098403

fbshipit-source-id: c7b7e811ab48b9c5a12662dc1e2f2ed694724176
parent 9998bbfa
...@@ -7,10 +7,50 @@ ...@@ -7,10 +7,50 @@
import contextlib import contextlib
import os import os
import math
import numpy as np import numpy as np
def mask(sentence, mask_id, pad_id, dictionary_token_range, mask_ratio=.15):
"""mask tokens for masked language model training
Samples mask_ratio tokens that will be predicted by LM.
- 80%: Replace the word with mask_id
- 10%: replate the word with random token within dictionary_token_range
- 10%: keeps word unchanged
This function may not be efficient enough since we had multiple conversions
between np and torch, we can replace them with torch operators later
Args:
sentence: 1d tensor to be masked
mask_id: index to use for masking the sentence
pad_id: index to use for masking the target for tokens we aren't predicting
dictionary_token_range: range of indices in dictionary which can be used
for random word replacement (e.g. without special characters)
mask_ratio: ratio of tokens to be masked in the sentence
Return:
masked_sent: masked sentence
target: target with words which we are not predicting replaced by pad_id
"""
masked_sent = np.copy(sentence)
sent_length = len(sentence)
mask_num = math.ceil(sent_length * mask_ratio)
mask = np.random.choice(sent_length, mask_num)
target = np.copy(sentence)
for i in range(sent_length):
if i in mask:
rand = np.random.random()
if rand < 0.8:
masked_sent[i] = mask_id
elif rand < 0.9:
# sample random token
masked_sent[i] = (
np.random.randint(dictionary_token_range[0], dictionary_token_range[1])
)
else:
target[i] = pad_id
return masked_sent, target
def infer_language_pair(path): def infer_language_pair(path):
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx""" """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
src, dst = None, None src, dst = None, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment