Commit 37c9d96f authored by Louis MARTIN's avatar Louis MARTIN Committed by Facebook Github Bot
Browse files

Add whole word masking for SentencepieceBPE (#1292)

Summary:
Models seem to train fine with this modification. I checked that the mask for beginning of words is correct but didn't check if the actual masking worked correctly.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1292

Differential Revision: D18338307

Pulled By: myleott

fbshipit-source-id: eae9e29d6ab648e768d70921694a898554496704
parent 7ca56cb8
......@@ -31,3 +31,13 @@ class SentencepieceBPE(object):
def decode(self, x: str) -> str:
return x.replace(' ', '').replace('\u2581', ' ').strip()
def is_beginning_of_word(self, x: str) -> bool:
if x in ['<unk>', '<s>', '</s>', '<pad>']:
# special elements are always considered beginnings
# HACK: this logic is already present in fairseq/tasks/masked_lm.py
# but these special tokens are also contained in the sentencepiece
# vocabulary which causes duplicate special tokens. This hack makes
# sure that they are all taken into account.
return True
return x.startswith('\u2581')
......@@ -108,23 +108,23 @@ class MaskedLMTask(FairseqTask):
# create masked input and targets
if self.args.mask_whole_words:
bpe = encoders.build_bpe(self.args)
if bpe is not None:
def is_beginning_of_word(i):
if i < self.source_dictionary.nspecial:
# special elements are always considered beginnings
return True
tok = self.source_dictionary[i]
if tok.startswith('madeupword'):
return True
try:
return bpe.is_beginning_of_word(tok)
except ValueError:
return True
mask_whole_words = torch.ByteTensor(list(
map(is_beginning_of_word, range(len(self.source_dictionary)))
))
assert bpe is not None
def is_beginning_of_word(i):
if i < self.source_dictionary.nspecial:
# special elements are always considered beginnings
return True
tok = self.source_dictionary[i]
if tok.startswith('madeupword'):
return True
try:
return bpe.is_beginning_of_word(tok)
except ValueError:
return True
mask_whole_words = torch.ByteTensor(list(
map(is_beginning_of_word, range(len(self.source_dictionary)))
))
else:
mask_whole_words = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment