Commit 37c9d96f authored by Louis MARTIN, committed by Facebook Github Bot

Add whole word masking for SentencepieceBPE (#1292)

Summary:
Models seem to train fine with this modification. I verified that the beginning-of-word mask is computed correctly, but did not check whether the actual masking is applied correctly.
Pull Request resolved: https://github.com/pytorch/fairseq/pull/1292

Differential Revision: D18338307

Pulled By: myleott

fbshipit-source-id: eae9e29d6ab648e768d70921694a898554496704
parent 7ca56cb8
fairseq/data/encoders/sentencepiece_bpe.py
@@ -31,3 +31,13 @@ class SentencepieceBPE(object):
 
     def decode(self, x: str) -> str:
         return x.replace(' ', '').replace('\u2581', ' ').strip()
+
+    def is_beginning_of_word(self, x: str) -> bool:
+        if x in ['<unk>', '<s>', '</s>', '<pad>']:
+            # special elements are always considered beginnings
+            # HACK: this logic is already present in fairseq/tasks/masked_lm.py
+            # but these special tokens are also contained in the sentencepiece
+            # vocabulary which causes duplicate special tokens. This hack makes
+            # sure that they are all taken into account.
+            return True
+        return x.startswith('\u2581')
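For context, here is a minimal standalone sketch (not part of this diff) of how the new predicate behaves on SentencePiece output. SentencePiece prefixes word-initial pieces with the meta symbol '\u2581' (LOWER ONE EIGHTH BLOCK), which is what the check relies on; the example tokenization below is illustrative.

# Minimal sketch (not from this commit) of the predicate's behavior.
# "unbelievable story" may be segmented by SentencePiece as
# ['\u2581un', 'believ', 'able', '\u2581story'].
def is_beginning_of_word(x: str) -> bool:
    if x in ['<unk>', '<s>', '</s>', '<pad>']:
        return True  # special tokens count as word beginnings
    return x.startswith('\u2581')

assert is_beginning_of_word('\u2581un')    # word-initial piece
assert not is_beginning_of_word('believ')  # word-internal piece
assert is_beginning_of_word('<pad>')       # special token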
...@@ -108,7 +108,7 @@ class MaskedLMTask(FairseqTask): ...@@ -108,7 +108,7 @@ class MaskedLMTask(FairseqTask):
# create masked input and targets # create masked input and targets
if self.args.mask_whole_words: if self.args.mask_whole_words:
bpe = encoders.build_bpe(self.args) bpe = encoders.build_bpe(self.args)
if bpe is not None: assert bpe is not None
def is_beginning_of_word(i): def is_beginning_of_word(i):
if i < self.source_dictionary.nspecial: if i < self.source_dictionary.nspecial:
......
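To illustrate why the predicate matters, here is a hypothetical sketch of how a per-token beginning-of-word flag enables whole-word masking: consecutive subword pieces are grouped into words, and every piece of a sampled word is masked together. The function name, sampling scheme, and '<mask>' symbol are illustrative, not the exact fairseq implementation.

import random

def whole_word_mask(pieces, is_beginning_of_word, mask_prob=0.15, mask='<mask>'):
    # Group subword indices into words using the beginning-of-word flags.
    words, current = [], []
    for i, piece in enumerate(pieces):
        if is_beginning_of_word(piece) and current:
            words.append(current)
            current = []
        current.append(i)
    if current:
        words.append(current)
    # Sample whole words, then replace every piece inside a sampled word.
    out = list(pieces)
    for word in words:
        if random.random() < mask_prob:
            for i in word:
                out[i] = mask
    return out

pieces = ['\u2581un', 'believ', 'able', '\u2581story']
print(whole_word_mask(pieces, lambda p: p.startswith('\u2581'), mask_prob=1.0))
# ['<mask>', '<mask>', '<mask>', '<mask>'] -- all pieces of a word are masked together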