Unverified commit 21b3922e, authored by Sylvain Gugger and committed by GitHub

Authorize last version of tokenizer (#9799)



* Authorize last version of tokenizer

* Update version table

* Fix conversion of spm tokenizers and fix some hub links

* Bump tokenizers version to 0.10.1rc1

* Add script to check tokenizers conversion with XNLI

* Add some more mask_token lstrip support

* Must modify mask_token in slow tokenizers too

* Keep using the old method for Pegasus

* add missing import
Co-authored-by: Anthony MOI <m.anthony.moi@gmail.com>
parent d5888ef0
# Check the fast (Rust) tokenizers against their slow (SentencePiece-based) counterparts
# on the XNLI test + validation splits, counting lines that match perfectly, that differ
# only harmlessly, or that genuinely diverge.
from collections import Counter

import datasets

import transformers
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
from transformers.utils import logging


logging.set_verbosity_info()

TOKENIZER_CLASSES = {
    name: (getattr(transformers, name), getattr(transformers, name + "Fast")) for name in SLOW_TO_FAST_CONVERTERS
}

dataset = datasets.load_dataset("xnli", split="test+validation")

total = 0
perfect = 0
imperfect = 0
wrong = 0


def check_diff(spm_diff, tok_diff, slow, fast):
    if spm_diff == list(reversed(tok_diff)):
        # AAA -> AA+A vs A+AA case.
        return True
    elif len(spm_diff) == len(tok_diff) and fast.decode(spm_diff) == fast.decode(tok_diff):
        # Second order OK
        # Barrich -> Barr + ich vs Bar + rich
        return True
    spm_reencoded = slow.encode(slow.decode(spm_diff))
    tok_reencoded = fast.encode(fast.decode(spm_diff))
    if spm_reencoded != spm_diff and spm_reencoded == tok_reencoded:
        # Type 3 error.
        # Snehagatha ->
        #       Sne, h, aga, th, a
        #       Sne, ha, gat, ha
        # Re-encoding the diverging span with the slow tokenizer does not even recover
        # what spm gave us, but it does match what the fast tokenizer produces.
        return True
    return False


def check_LTR_mark(line, idx, fast):
    # Mismatches around a right-to-left mark (U+200F) are tolerated.
    enc = fast.encode_plus(line)[0]
    offsets = enc.offsets
    curr, prev = offsets[idx], offsets[idx - 1]
    if curr is not None and line[curr[0] : curr[1]] == "\u200f":
        return True
    if prev is not None and line[prev[0] : prev[1]] == "\u200f":
        return True
    return False


def check_details(line, spm_ids, tok_ids, slow, fast):
    # Encodings can differ yet decode to the same string: AAA -> A + AA vs AA + A.
    # Strip the common prefix and suffix so only the diverging span is inspected.
    for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)):
        if spm_id != tok_id:
            break
    first = i
    for i, (spm_id, tok_id) in enumerate(zip(reversed(spm_ids), reversed(tok_ids))):
        if spm_id != tok_id:
            break
    last = len(spm_ids) - i

    spm_diff = spm_ids[first:last]
    tok_diff = tok_ids[first:last]

    if check_diff(spm_diff, tok_diff, slow, fast):
        return True

    if check_LTR_mark(line, first, fast):
        return True

    if last - first > 5:
        # The same problem may occur twice in a long span; attempt to subdivide the
        # disjoint tokens into smaller problems.
        spms = Counter(spm_ids[first:last])
        toks = Counter(tok_ids[first:last])

        removable_tokens = {spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si}
        min_width = 3
        for i in range(last - first - min_width):
            if all(spm_ids[first + i + j] in removable_tokens for j in range(min_width)):
                possible_matches = [
                    k
                    for k in range(last - first - min_width)
                    if tok_ids[first + k : first + k + min_width] == spm_ids[first + i : first + i + min_width]
                ]
                for j in possible_matches:
                    if check_diff(spm_ids[first : first + i], tok_ids[first : first + j], slow, fast) and check_details(
                        line,
                        spm_ids[first + i : last],
                        tok_ids[first + j : last],
                        slow,
                        fast,
                    ):
                        return True

    print(f"Spm: {[fast.decode([spm_ids[i]]) for i in range(first, last)]}")
    try:
        print(f"Tok: {[fast.decode([tok_ids[i]]) for i in range(first, last)]}")
    except Exception:
        pass

    # Decode the surrounding context and the diverging span for inspection.
    ok_start = fast.decode(spm_ids[:first])
    ok_end = fast.decode(spm_ids[last:])
    wrong = fast.decode(spm_ids[first:last])
    print()
    print(wrong)
    return False


def test_string(slow, fast, text):
    global perfect
    global imperfect
    global wrong
    global total

    slow_ids = slow.encode(text)
    fast_ids = fast.encode(text)

    skip_assert = False
    total += 1

    if slow_ids != fast_ids:
        if check_details(text, slow_ids, fast_ids, slow, fast):
            skip_assert = True
            imperfect += 1
        else:
            wrong += 1
    else:
        perfect += 1

    if total % 10000 == 0:
        print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")

    if skip_assert:
        return

    assert (
        slow_ids == fast_ids
    ), f"line {text} : \n\n{slow_ids}\n{fast_ids}\n\n{slow.tokenize(text)}\n{fast.tokenize(text)}"


def test_tokenizer(slow, fast):
    global batch_total
    for i in range(len(dataset)):
        # premise, all languages
        for text in dataset[i]["premise"].values():
            test_string(slow, fast, text)

        # hypothesis, all languages
        for text in dataset[i]["hypothesis"]["translation"]:
            test_string(slow, fast, text)


if __name__ == "__main__":
    for name, (slow_class, fast_class) in TOKENIZER_CLASSES.items():
        checkpoint_names = list(slow_class.max_model_input_sizes.keys())
        for checkpoint in checkpoint_names:
            imperfect = 0
            perfect = 0
            wrong = 0
            total = 0

            print(f"========================== Checking {name}: {checkpoint} ==========================")
            slow = slow_class.from_pretrained(checkpoint, force_download=True)
            fast = fast_class.from_pretrained(checkpoint, force_download=True)
            test_tokenizer(slow, fast)
            print(f"Accuracy {perfect * 100 / total:.2f}")
@@ -132,7 +132,7 @@ _deps = [
     "tensorflow-cpu>=2.3",
     "tensorflow>=2.3",
     "timeout-decorator",
-    "tokenizers==0.9.4",
+    "tokenizers==0.10.1rc1",
     "torch>=1.0",
     "tqdm>=4.27",
     "unidic>=1.0.2",
...
@@ -21,7 +21,7 @@
 from typing import Dict, List, Tuple
 
-from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
+from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import BPE, Unigram, WordPiece
 
 from .file_utils import requires_protobuf, requires_sentencepiece
@@ -340,7 +340,12 @@ class SpmConverter(Converter):
     def normalizer(self, proto):
         precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
-        return normalizers.Precompiled(precompiled_charsmap)
+        return normalizers.Sequence(
+            [normalizers.Precompiled(precompiled_charsmap), normalizers.Replace(Regex(" {2,}"), " ")]
+        )
+
+    def pre_tokenizer(self, replacement, add_prefix_space):
+        return pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
 
     def post_processor(self):
         return None
@@ -353,12 +358,7 @@ class SpmConverter(Converter):
         replacement = "▁"
         add_prefix_space = True
-        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
-            [
-                pre_tokenizers.WhitespaceSplit(),
-                pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
-            ]
-        )
+        tokenizer.pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
         tokenizer.decoder = decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
 
         post_processor = self.post_processor()
         if post_processor:
@@ -375,7 +375,11 @@ class AlbertConverter(SpmConverter):
         ]
 
     def normalizer(self, proto):
-        list_normalizers = [normalizers.Replace("``", '"'), normalizers.Replace("''", '"')]
+        list_normalizers = [
+            normalizers.Replace("``", '"'),
+            normalizers.Replace("''", '"'),
+            normalizers.Replace(Regex(" {2,}"), " "),
+        ]
         if not self.original_tokenizer.keep_accents:
             list_normalizers.append(normalizers.NFKD())
             list_normalizers.append(normalizers.StripAccents())
@@ -529,7 +533,11 @@ class XLNetConverter(SpmConverter):
         ]
 
     def normalizer(self, proto):
-        list_normalizers = [normalizers.Replace("``", '"'), normalizers.Replace("''", '"')]
+        list_normalizers = [
+            normalizers.Replace("``", '"'),
+            normalizers.Replace("''", '"'),
+            normalizers.Replace(Regex(" {2,}"), " "),
+        ]
         if not self.original_tokenizer.keep_accents:
             list_normalizers.append(normalizers.NFKD())
             list_normalizers.append(normalizers.StripAccents())
@@ -574,6 +582,14 @@ class PegasusConverter(SpmConverter):
     def unk_id(self, proto):
         return proto.trainer_spec.unk_id + self.original_tokenizer.offset
 
+    def pre_tokenizer(self, replacement, add_prefix_space):
+        return pre_tokenizers.Sequence(
+            [
+                pre_tokenizers.WhitespaceSplit(),
+                pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
+            ]
+        )
+
     def post_processor(self):
         eos = self.original_tokenizer.eos_token
         special_tokens = [
...
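To illustrate the SpmConverter change above (a minimal sketch using only the tokenizers library; the sample string is made up): the extra normalizer collapses runs of two or more spaces into one before the Metaspace pre-tokenizer runs, which is what lets the generic converter drop the WhitespaceSplit step that Pegasus keeps through its own pre_tokenizer override.

from tokenizers import Regex, normalizers

# Same rule as the one appended in SpmConverter.normalizer: two or more spaces -> one space.
collapse_spaces = normalizers.Replace(Regex(" {2,}"), " ")
print(collapse_spaces.normalize_str("Hello   world,  twice"))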
@@ -45,7 +45,7 @@ deps = {
     "tensorflow-cpu": "tensorflow-cpu>=2.3",
     "tensorflow": "tensorflow>=2.3",
     "timeout-decorator": "timeout-decorator",
-    "tokenizers": "tokenizers==0.9.4",
+    "tokenizers": "tokenizers==0.10.1rc1",
     "torch": "torch>=1.0",
     "tqdm": "tqdm>=4.27",
     "unidic": "unidic>=1.0.2",
...
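Both the setup.py entry and the dependency table above pin tokenizers==0.10.1rc1. A quick way to confirm an environment picked up the new pin (a hypothetical sanity check, not part of the commit):

import tokenizers

# Expect the pinned release candidate once the updated requirements are installed.
print(tokenizers.__version__)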
@@ -22,7 +22,7 @@ from typing import List, Optional, Tuple
 
 import sentencepiece as spm
 
-from ...tokenization_utils import PreTrainedTokenizer
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import logging
@@ -127,6 +127,9 @@ class AlbertTokenizer(PreTrainedTokenizer):
         mask_token="[MASK]",
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             do_lower_case=do_lower_case,
             remove_space=remove_space,
...
@@ -20,6 +20,7 @@ from shutil import copyfile
 from typing import List, Optional, Tuple
 
 from ...file_utils import is_sentencepiece_available
+from ...tokenization_utils import AddedToken
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
@@ -134,6 +135,9 @@ class AlbertTokenizerFast(PreTrainedTokenizerFast):
         mask_token="[MASK]",
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             vocab_file,
             tokenizer_file=tokenizer_file,
...
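The same three-line mask_token pattern is applied to every SentencePiece-backed tokenizer touched in the hunks below. A minimal sketch of what lstrip=True changes (the checkpoint name and example sentence are illustrative only): the space before the mask token is absorbed by the token itself instead of producing a stray piece in front of it.

from transformers import AlbertTokenizerFast

# With lstrip=True the space before "[MASK]" becomes part of the mask token match,
# mirroring the AddedToken(...) line added in each __init__ above.
tok = AlbertTokenizerFast.from_pretrained("albert-base-v2")
print(tok.tokenize("Paris is the [MASK] of France."))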
@@ -21,7 +21,7 @@ from typing import List, Optional, Tuple
 
 import sentencepiece as spm
 
-from ...tokenization_utils import PreTrainedTokenizer
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import logging
@@ -112,6 +112,9 @@ class BarthezTokenizer(PreTrainedTokenizer):
         mask_token="<mask>",
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             bos_token=bos_token,
             eos_token=eos_token,
...
@@ -20,6 +20,7 @@ from shutil import copyfile
 from typing import List, Optional, Tuple
 
 from ...file_utils import is_sentencepiece_available
+from ...tokenization_utils import AddedToken
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
@@ -119,6 +120,9 @@ class BarthezTokenizerFast(PreTrainedTokenizerFast):
         mask_token="<mask>",
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             vocab_file,
             tokenizer_file=tokenizer_file,
...
@@ -21,7 +21,7 @@ from typing import List, Optional, Tuple
 
 import sentencepiece as spm
 
-from ...tokenization_utils import PreTrainedTokenizer
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import logging
@@ -116,6 +116,9 @@ class CamembertTokenizer(PreTrainedTokenizer):
         additional_special_tokens=["<s>NOTUSED", "</s>NOTUSED"],
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             bos_token=bos_token,
             eos_token=eos_token,
...
@@ -20,6 +20,7 @@ from shutil import copyfile
 from typing import List, Optional, Tuple
 
 from ...file_utils import is_sentencepiece_available
+from ...tokenization_utils import AddedToken
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
@@ -123,6 +124,9 @@ class CamembertTokenizerFast(PreTrainedTokenizerFast):
         additional_special_tokens=["<s>NOTUSED", "</s>NOTUSED"],
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             vocab_file,
             tokenizer_file=tokenizer_file,
...
@@ -27,7 +27,7 @@ SPIECE_UNDERLINE = "▁"
 VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
 
 PRETRAINED_VOCAB_FILES_MAP = {
-    "vocab_file": {"google/pegasus-xsum": "https://cdn.huggingface.co/google/pegasus-xsum/spiece.model"}
+    "vocab_file": {"google/pegasus-xsum": "https://huggingface.co/google/pegasus-xsum/resolve/main/spiece.model"}
 }
 
 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
...
@@ -38,8 +38,10 @@ SPIECE_UNDERLINE = "▁"
 VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}
 
 PRETRAINED_VOCAB_FILES_MAP = {
-    "vocab_file": {"google/pegasus-xsum": "https://cdn.huggingface.co/google/pegasus-xsum/spiece.model"},
-    "tokenizer_file": {"google/pegasus-xsum": "https://cdn.huggingface.co/google/pegasus-xsum/tokenizer.json"},
+    "vocab_file": {"google/pegasus-xsum": "https://huggingface.co/google/pegasus-xsum/resolve/main/spiece.model"},
+    "tokenizer_file": {
+        "google/pegasus-xsum": "https://huggingface.co/google/pegasus-xsum/resolve/main/tokenizer.json"
+    },
 }
 
 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
...
@@ -42,7 +42,7 @@ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
 ####################################################
 
 PRETRAINED_VOCAB_FILES_MAP = {
     "vocab_file": {
-        "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/spiece.model"
+        "google/reformer-crime-and-punishment": "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/spiece.model"
     }
 }
...
@@ -47,10 +47,10 @@ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}
 ####################################################
 
 PRETRAINED_VOCAB_FILES_MAP = {
     "vocab_file": {
-        "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/spiece.model"
+        "google/reformer-crime-and-punishment": "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/spiece.model"
     },
     "tokenizer_file": {
-        "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/tokenizer.json"
+        "google/reformer-crime-and-punishment": "https://huggingface.co/google/reformer-crime-and-punishment/resolve/main/tokenizer.json"
     },
 }
...
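The Pegasus and Reformer hunks above only swap the legacy cdn.huggingface.co URLs for the canonical huggingface.co/.../resolve/main/ form; loading from the Hub is unchanged from the user's point of view (a minimal sketch, assuming network access to the Hub):

from transformers import PegasusTokenizerFast

# spiece.model and tokenizer.json now resolve from huggingface.co/.../resolve/main/.
tok = PegasusTokenizerFast.from_pretrained("google/pegasus-xsum")
print(tok("Pegasus summarization input.")["input_ids"])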
@@ -21,7 +21,7 @@ from typing import List, Optional, Tuple
 
 import sentencepiece as spm
 
-from ...tokenization_utils import PreTrainedTokenizer
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import logging
@@ -117,6 +117,9 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
         mask_token="<mask>",
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             bos_token=bos_token,
             eos_token=eos_token,
...
@@ -20,6 +20,7 @@ from shutil import copyfile
 from typing import List, Optional, Tuple
 
 from ...file_utils import is_sentencepiece_available
+from ...tokenization_utils import AddedToken
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
@@ -127,6 +128,9 @@ class XLMRobertaTokenizerFast(PreTrainedTokenizerFast):
         mask_token="<mask>",
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             vocab_file,
             tokenizer_file=tokenizer_file,
...
@@ -23,7 +23,7 @@ from typing import List, Optional, Tuple
 import sentencepiece as spm
 
 from ...file_utils import SPIECE_UNDERLINE
-from ...tokenization_utils import PreTrainedTokenizer
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import logging
@@ -126,6 +126,9 @@ class XLNetTokenizer(PreTrainedTokenizer):
         additional_special_tokens=["<eop>", "<eod>"],
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             do_lower_case=do_lower_case,
             remove_space=remove_space,
...
@@ -20,6 +20,7 @@ from shutil import copyfile
 from typing import List, Optional, Tuple
 
 from ...file_utils import is_sentencepiece_available
+from ...tokenization_utils import AddedToken
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
@@ -138,6 +139,9 @@ class XLNetTokenizerFast(PreTrainedTokenizerFast):
         additional_special_tokens=["<eop>", "<eod>"],
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             vocab_file=vocab_file,
             tokenizer_file=tokenizer_file,
...