Commit 39f426be authored by LysandreJik

Added special tokens <pad> and <mask> to RoBERTa.

parent baf08ca1
@@ -279,7 +279,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
         sep_token=tokenizer.sep_token,
         sep_token_extra=bool(args.model_type in ['roberta']),  # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
         pad_on_left=bool(args.model_type in ['xlnet']),  # pad on the left for xlnet
-        pad_token=1 if args.model_type in ['roberta'] else 0,  # TODO(Lysandre): replace with tokenizer.pad_token when implemented
+        pad_token=tokenizer.encoder[tokenizer.pad_token] if args.model_type in ['roberta'] else tokenizer.vocab[tokenizer.pad_token],
         pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
     )
     if args.local_rank in [-1, 0]:
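The example script now reads the pad token id from the tokenizer's vocabulary instead of hardcoding it. A minimal sketch of the equivalent lookup (the printed ids assume the standard roberta-base and bert-base-uncased pretrained vocab files):

from pytorch_transformers import BertTokenizer, RobertaTokenizer

# Load pretrained vocabularies (downloaded and cached on first use).
roberta_tok = RobertaTokenizer.from_pretrained('roberta-base')
bert_tok = BertTokenizer.from_pretrained('bert-base-uncased')

# RoBERTa keeps its BPE vocab in `encoder`; BERT-style tokenizers use `vocab`.
# With this commit, '<pad>' is a registered special token, so the lookup below
# replaces the previously hardcoded id 1.
print(roberta_tok.encoder[roberta_tok.pad_token])  # 1 for the roberta-base vocab
print(bert_tok.vocab[bert_tok.pad_token])          # 0 ('[PAD]') for bert-base-uncased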
@@ -73,9 +73,10 @@ class RobertaTokenizer(PreTrainedTokenizer):
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

-    def __init__(self, vocab_file, merges_file, errors='replace', bos_token="<s>", eos_token="</s>", sep_token="</s>",
-                 cls_token="<s>", unk_token="<unk>", **kwargs):
-        super(RobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
-                                               sep_token=sep_token, cls_token=cls_token, **kwargs)
+    def __init__(self, vocab_file, merges_file, errors='replace', bos_token="<s>", eos_token="</s>", sep_token="</s>",
+                 cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>', **kwargs):
+        super(RobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
+                                               sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
+                                               mask_token=mask_token, **kwargs)
         self.encoder = json.load(open(vocab_file, encoding="utf-8"))
         self.decoder = {v: k for k, v in self.encoder.items()}
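With <pad> and <mask> registered as special tokens, they surface through the standard special-token properties of PreTrainedTokenizer. A minimal usage sketch; the ids shown assume the standard roberta-base vocab files:

from pytorch_transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# The new keyword arguments are exposed as the usual tokenizer properties.
print(tokenizer.pad_token)   # '<pad>'
print(tokenizer.mask_token)  # '<mask>'

# Their ids come straight from the BPE vocabulary, so downstream code no
# longer needs to hardcode them.
print(tokenizer.encoder[tokenizer.pad_token])   # 1 in the roberta-base vocab
print(tokenizer.encoder[tokenizer.mask_token])  # 50264 in the roberta-base vocab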