Unverified commit 6915174e, authored by Suraj Patil and committed by GitHub

[RobertaTokenizer] remove inheritance on GPT2Tokenizer (#15429)

* refactor roberta tokenizer

* refactor fast tokenizer

* remove old comment
parent a5ecbf73
src/transformers/models/roberta/tokenization_roberta.py
@@ -14,11 +14,15 @@
 # limitations under the License.
 """Tokenization classes for RoBERTa."""

-from typing import List, Optional
+import json
+import os
+from functools import lru_cache
+from typing import List, Optional, Tuple
+
+import regex as re

-from ...tokenization_utils import AddedToken
+from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import logging
-from ..gpt2.tokenization_gpt2 import GPT2Tokenizer


 logger = logging.get_logger(__name__)
@@ -57,7 +61,46 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
 }


+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns a mapping between utf-8 bytes and unicode strings. We specifically avoid mapping to whitespace/control
+    characters the bpe code barfs on.
+
+    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your
+    vocab if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K
+    for decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want
+    lookup tables between utf-8 bytes and unicode strings.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2 ** 8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2 ** 8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """
+    Return the set of symbol pairs in a word.
+
+    Word is represented as a tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
-class RobertaTokenizer(GPT2Tokenizer):
+class RobertaTokenizer(PreTrainedTokenizer):
     """
     Constructs a RoBERTa tokenizer, derived from the GPT-2 tokenizer, using byte-level Byte-Pair-Encoding.
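For context on the two helpers copied in above, here is a minimal sketch of what they produce, assuming a transformers checkout at this commit (where both are importable from the RoBERTa tokenization module):

```python
from transformers.models.roberta.tokenization_roberta import bytes_to_unicode, get_pairs

byte_encoder = bytes_to_unicode()
print(len(byte_encoder))       # 256: every byte value gets a printable unicode character
print(byte_encoder[ord(" ")])  # 'Ġ': the space byte is remapped above the printable ASCII range

word = tuple("hello")
print(get_pairs(word))  # {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')} (set, order may vary)
```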
@@ -164,8 +207,6 @@ class RobertaTokenizer(GPT2Tokenizer):
         mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

         super().__init__(
-            vocab_file=vocab_file,
-            merges_file=merges_file,
             errors=errors,
             bos_token=bos_token,
             eos_token=eos_token,
@@ -178,6 +219,124 @@
             **kwargs,
         )
+
+        with open(vocab_file, encoding="utf-8") as vocab_handle:
+            self.encoder = json.load(vocab_handle)
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.errors = errors  # how to handle errors in decoding
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        with open(merges_file, encoding="utf-8") as merges_handle:
+            bpe_merges = merges_handle.read().split("\n")[1:-1]
+        bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        self.cache = {}
+        self.add_prefix_space = add_prefix_space
+
+        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
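The pattern compiled at the end of `__init__` pre-splits text into contraction suffixes, letter runs, number runs, other symbols, and whitespace before BPE runs on each piece. A minimal sketch of its behavior, using the same regex as above:

```python
import regex as re  # the `regex` package, needed for \p{L} / \p{N} classes

pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
print(re.findall(pat, "We're testing RoBERTa in 2022!"))
# ["We", "'re", " testing", " RoBERTa", " in", " 2022", "!"]
```

Note how leading spaces stay attached to words; after byte encoding they become the 'Ġ' prefix seen in the vocabulary.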
+
+    @property
+    def vocab_size(self):
+        return len(self.encoder)
+
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+                else:
+                    new_word.extend(word[i:j])
+                    i = j
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = " ".join(word)
+        self.cache[token] = word
+        return word
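The loop in `bpe` repeatedly merges the lowest-ranked adjacent pair until no known merge remains, then joins the surviving symbols with spaces. A toy re-implementation sketch of that merge order; `toy_bpe` and the two-rule merge table are hypothetical, for illustration only:

```python
def toy_bpe(token, bpe_ranks):
    word = tuple(token)
    while len(word) > 1:
        pairs = {(word[i], word[i + 1]) for i in range(len(word) - 1)}
        # pick the adjacent pair with the lowest merge rank, as bpe() does
        bigram = min(pairs, key=lambda p: bpe_ranks.get(p, float("inf")))
        if bigram not in bpe_ranks:
            break  # no mergeable pair left
        merged, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and (word[i], word[i + 1]) == bigram:
                merged.append(word[i] + word[i + 1])
                i += 2
            else:
                merged.append(word[i])
                i += 1
        word = tuple(merged)
    return " ".join(word)

# 'l'+'o' merges first (rank 0), then 'lo'+'w' (rank 1), then 'e'+'r' (rank 2)
print(toy_bpe("lower", {("l", "o"): 0, ("lo", "w"): 1, ("e", "r"): 2}))  # "low er"
```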
+
+    def _tokenize(self, text):
+        """Tokenize a string."""
+        bpe_tokens = []
+        for token in re.findall(self.pat, text):
+            token = "".join(
+                self.byte_encoder[b] for b in token.encode("utf-8")
+            )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
+            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+        return bpe_tokens
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) to an id using the vocab."""
+        return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) to a token (str) using the vocab."""
+        return self.decoder.get(index)
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings) into a single string."""
+        text = "".join(tokens)
+        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+        return text
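`convert_tokens_to_string` inverts `_tokenize` by mapping each unicode character back to its original byte through `byte_decoder`. A hedged round-trip sketch using the public API (the exact token strings shown depend on the `roberta-base` vocabulary):

```python
from transformers import RobertaTokenizer

tok = RobertaTokenizer.from_pretrained("roberta-base")
tokens = tok.tokenize("Hello world")         # e.g. ['Hello', 'Ġworld']; 'Ġ' is the remapped space byte
print(tok.convert_tokens_to_string(tokens))  # "Hello world", recovered byte for byte
```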
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+        merge_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+        )
+
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.encoder, ensure_ascii=False))
+
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write("#version: 0.2\n")
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning(
+                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+                        " Please check that the tokenizer is not corrupted!"
+                    )
+                    index = token_index
+                writer.write(" ".join(bpe_tokens) + "\n")
+                index += 1
+
+        return vocab_file, merge_file
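`save_vocabulary` writes the encoder as a JSON vocab file and the merge rules, sorted by rank, as a merges file. A hedged usage sketch showing that the saved files reconstruct the same vocabulary:

```python
import tempfile

from transformers import RobertaTokenizer

tok = RobertaTokenizer.from_pretrained("roberta-base")
with tempfile.TemporaryDirectory() as tmp:
    vocab_file, merge_file = tok.save_vocabulary(tmp)   # writes vocab.json and merges.txt into tmp
    reloaded = RobertaTokenizer(vocab_file, merge_file)
    assert reloaded.encoder == tok.encoder              # same vocab after the round trip
```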
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
...
src/transformers/models/roberta/tokenization_roberta_fast.py
@@ -14,13 +14,13 @@
 # limitations under the License.
 """Fast Tokenization classes for RoBERTa."""
 import json
-from typing import List, Optional
+from typing import List, Optional, Tuple

-from tokenizers import processors
+from tokenizers import pre_tokenizers, processors

-from ...tokenization_utils_base import AddedToken
+from ...tokenization_utils_base import AddedToken, BatchEncoding
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
-from ..gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
 from .tokenization_roberta import RobertaTokenizer
@@ -65,7 +65,7 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
 }


-class RobertaTokenizerFast(GPT2TokenizerFast):
+class RobertaTokenizerFast(PreTrainedTokenizerFast):
     """
     Construct a "fast" RoBERTa tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2
     tokenizer, using byte-level Byte-Pair-Encoding.
@@ -184,7 +184,14 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
             **kwargs,
         )

-        # the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__`
+        pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
+        if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
+            pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
+            pre_tok_state["add_prefix_space"] = add_prefix_space
+            self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
+
+        self.add_prefix_space = add_prefix_space
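This block replaces the pre-tokenizer update that previously happened in `GPT2TokenizerFast.__init__`: the serialized state of the backend ByteLevel pre-tokenizer is patched so the `add_prefix_space` argument takes effect. A hedged sketch of the observable difference (expected token strings depend on the `roberta-base` vocabulary):

```python
from transformers import RobertaTokenizerFast

default = RobertaTokenizerFast.from_pretrained("roberta-base")
prefixed = RobertaTokenizerFast.from_pretrained("roberta-base", add_prefix_space=True)

print(default.tokenize("Hello"))   # ['Hello']
print(prefixed.tokenize("Hello"))  # ['ĠHello'], encoded as if preceded by a space
```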
tokenizer_component = "post_processor" tokenizer_component = "post_processor"
tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None) tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
if tokenizer_component_instance: if tokenizer_component_instance:
...@@ -237,6 +244,29 @@ class RobertaTokenizerFast(GPT2TokenizerFast): ...@@ -237,6 +244,29 @@ class RobertaTokenizerFast(GPT2TokenizerFast):
value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
self._mask_token = value self._mask_token = value
+    def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
+        is_split_into_words = kwargs.get("is_split_into_words", False)
+        assert self.add_prefix_space or not is_split_into_words, (
+            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
+            "to use it with pretokenized inputs."
+        )
+
+        return super()._batch_encode_plus(*args, **kwargs)
+
+    def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
+        is_split_into_words = kwargs.get("is_split_into_words", False)
+        assert self.add_prefix_space or not is_split_into_words, (
+            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
+            "to use it with pretokenized inputs."
+        )
+
+        return super()._encode_plus(*args, **kwargs)
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
+        return tuple(files)
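The assertions ported into `_batch_encode_plus` / `_encode_plus` guard pretokenized input: without a prefix space, every word after the first would be byte-encoded as if glued to the previous one. A hedged sketch of the supported path (expected output shown as a comment):

```python
from transformers import RobertaTokenizerFast

tok = RobertaTokenizerFast.from_pretrained("roberta-base", add_prefix_space=True)
enc = tok(["Hello", "world"], is_split_into_words=True)
print(enc.tokens())  # expected: ['<s>', 'ĠHello', 'Ġworld', '</s>'], each word space-prefixed
```

Instantiating without `add_prefix_space=True` and passing `is_split_into_words=True` trips the assertion above instead.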
     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
         output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
         if token_ids_1 is None:
...