Unverified Commit 1670be4b authored by Nicolas Patry, committed by GitHub

Adding Llama FastTokenizer support. (#22264)

* Adding Llama FastTokenizer support.

- Requires https://github.com/huggingface/tokenizers/pull/1183 version
- Only support byte_fallback for Llama; raise otherwise (safety net). A short byte-fallback illustration follows the test script below.
- Lots of open questions around special tokens.

How to test:

```python

from transformers.convert_slow_tokenizer import convert_slow_tokenizer
from transformers import AutoTokenizer
from tokenizers import Tokenizer

tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")

# Flip to True to reload a previously converted tokenizer from disk
# instead of re-running the (slow) conversion.
if False:
    new_tokenizer = Tokenizer.from_file("tok.json")
else:
    new_tokenizer = convert_slow_tokenizer(tokenizer)
    new_tokenizer.save("tok.json")

strings = [
    "This is a test",
    "生活的真谛是",
    "生活的真谛是[MASK]。",
    # XXX: This one is problematic because of special tokens
    # "<s> Something something",
]

for string in strings:
    encoded = tokenizer(string)["input_ids"]
    encoded2 = new_tokenizer.encode(string).ids

    assert encoded == encoded2, f"{encoded} != {encoded2}"

    decoded = tokenizer.decode(encoded)
    decoded2 = new_tokenizer.decode(encoded2)

    assert decoded.strip() == decoded2, f"{repr(decoded)} != {repr(decoded2)}"
```
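
For context on the byte-fallback bullet above: with `byte_fallback=True` in the BPE model, a character missing from the vocabulary is encoded as its raw UTF-8 byte tokens (`<0xE4>`, `<0xB8>`, ...) instead of collapsing to `<unk>`, and the `ByteFallback` decoder reassembles those bytes on decode. A hedged sketch using the internal test checkpoint from this PR's tests (the exact token output depends on the checkpoint's vocabulary):

```python
from transformers import LlamaTokenizerFast

tok = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")

# Characters absent from the vocabulary surface as byte tokens rather than <unk>.
print(tok.tokenize("生活的真谛是"))

# The byte tokens are fused back into text on decode, so the round trip is
# lossless where a non-byte-fallback conversion would produce <unk>.
ids = tok.encode("生活的真谛是")
print(tok.decode(ids, skip_special_tokens=True))  # 生活的真谛是
```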

The converter + some test script.

The test script.

Tmp save.

Adding Fast tokenizer + tests.

Adding the tokenization tests.

Correct combination.

Small fix.

Fixing tests.

Fixing with latest update.

Rebased.

fix copies + normalized added tokens + copies.

Adding doc.

TMP.

Doc + split files.

Doc.

Versions + try import.

Fix Camembert + warnings -> Error.

Fix by ArthurZucker.

Not a decorator.

* Fixing comments.

* Adding more to docstring.

* Doc rewriting.
parent 15641892
```diff
@@ -336,7 +336,7 @@ Flax), PyTorch, and/or TensorFlow.
 | LED | ✅ | ✅ | ✅ | ✅ | ❌ |
 | LeViT | ❌ | ❌ | ✅ | ❌ | ❌ |
 | LiLT | ❌ | ❌ | ✅ | ❌ | ❌ |
-| LLaMA | ✅ | ❌ | ✅ | ❌ | ❌ |
+| LLaMA | ✅ | ✅ | ✅ | ❌ | ❌ |
 | Longformer | ✅ | ✅ | ✅ | ✅ | ❌ |
 | LongT5 | ❌ | ❌ | ✅ | ❌ | ✅ |
 | LUKE | ✅ | ❌ | ✅ | ❌ | ❌ |
```
```diff
@@ -59,6 +59,14 @@ This model was contributed by [zphang](https://huggingface.co/zphang) with contr
     - create_token_type_ids_from_sequences
     - save_vocabulary

+## LlamaTokenizerFast
+
+[[autodoc]] LlamaTokenizerFast
+    - build_inputs_with_special_tokens
+    - get_special_tokens_mask
+    - create_token_type_ids_from_sequences
+    - save_vocabulary
+
 ## LlamaModel

 [[autodoc]] LlamaModel
```
```diff
@@ -78,7 +78,7 @@ import re
 import shutil
 from pathlib import Path

-from setuptools import setup, Command
+from setuptools import Command, setup

 # Remove stale transformers.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
@@ -251,6 +251,7 @@ class DepsTableUpdateCommand(Command):
         with open(target, "w", encoding="utf-8", newline="\n") as f:
             f.write("\n".join(content))

+
 extras = {}
 extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic", "sudachipy", "sudachidict_core", "rhoknp")
```
```diff
@@ -740,6 +740,7 @@ else:
     _import_structure["models.layoutlmv3"].append("LayoutLMv3TokenizerFast")
     _import_structure["models.layoutxlm"].append("LayoutXLMTokenizerFast")
     _import_structure["models.led"].append("LEDTokenizerFast")
+    _import_structure["models.llama"].append("LlamaTokenizerFast")
     _import_structure["models.longformer"].append("LongformerTokenizerFast")
     _import_structure["models.lxmert"].append("LxmertTokenizerFast")
     _import_structure["models.markuplm"].append("MarkupLMTokenizerFast")
@@ -4388,6 +4389,7 @@ if TYPE_CHECKING:
         from .models.layoutlmv3 import LayoutLMv3TokenizerFast
         from .models.layoutxlm import LayoutXLMTokenizerFast
         from .models.led import LEDTokenizerFast
+        from .models.llama import LlamaTokenizerFast
         from .models.longformer import LongformerTokenizerFast
         from .models.lxmert import LxmertTokenizerFast
         from .models.markuplm import MarkupLMTokenizerFast
```
```diff
@@ -19,10 +19,9 @@ All the conversions are grouped here to gather SentencePiece dependencies outsid
 allow to make our dependency on SentencePiece optional.
 """

-import warnings
 from typing import Dict, List, Tuple

-from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
+from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import BPE, Unigram, WordPiece

 from .utils import requires_backends
@@ -450,12 +449,13 @@ class SpmConverter(Converter):
         self.proto = m

         if self.proto.trainer_spec.byte_fallback:
-            warnings.warn(
-                "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
-                " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
-                " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
-                "unknown tokens into a sequence of byte tokens matching the original piece of text."
-            )
+            if not getattr(self, "handle_byte_fallback", None):
+                raise RuntimeError(
+                    "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
+                    " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
+                    " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
+                    "unknown tokens into a sequence of byte tokens matching the original piece of text."
+                )

     def vocab(self, proto):
         return [(piece.piece, piece.score) for piece in proto.pieces]
@@ -1094,6 +1094,78 @@ class XGLMConverter(SpmConverter):
         )

+
+class LlamaConverter(SpmConverter):
+    handle_byte_fallback = True
+
+    def vocab(self, proto):
+        vocab = [
+            ("<unk>", 0.0),
+            ("<s>", 0.0),
+            ("</s>", 0.0),
+        ]
+        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+        return vocab
+
+    def unk_id(self, proto):
+        unk_id = 0
+        return unk_id
+
+    def decoder(self, replacement, add_prefix_space):
+        return decoders.Sequence(
+            [
+                decoders.Replace("▁", " "),
+                decoders.ByteFallback(),
+                decoders.Fuse(),
+                decoders.Strip(content=" ", left=1),
+            ]
+        )
+
+    def tokenizer(self, proto):
+        model_type = proto.trainer_spec.model_type
+        vocab_scores = self.vocab(proto)
+        if model_type == 1:
+            raise RuntimeError("Llama is supposed to be a BPE model!")
+        elif model_type == 2:
+            _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
+            bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
+            tokenizer = Tokenizer(
+                BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
+            )
+            tokenizer.add_special_tokens(
+                [
+                    AddedToken("<unk>", normalized=True),
+                    AddedToken("<s>", normalized=True),
+                    AddedToken("</s>", normalized=True),
+                ]
+            )
+        else:
+            raise Exception(
+                "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
+            )
+
+        return tokenizer
+
+    def normalizer(self, proto):
+        return normalizers.Sequence(
+            [
+                normalizers.Prepend(prepend="▁"),
+                normalizers.Replace(pattern=" ", content="▁"),
+            ]
+        )
+
+    def pre_tokenizer(self, replacement, add_prefix_space):
+        return None
+
+    def post_processor(self):
+        return processors.TemplateProcessing(
+            single="<s> $A",
+            pair="<s> $A $B",
+            special_tokens=[
+                ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
+            ],
+        )
+
+
 class MarkupLMConverter(Converter):
     def converted(self) -> Tokenizer:
         ot = self.original_tokenizer
@@ -1183,6 +1255,7 @@ SLOW_TO_FAST_CONVERTERS = {
     "XLNetTokenizer": XLNetConverter,
     "SplinterTokenizer": SplinterConverter,
     "XGLMTokenizer": XGLMConverter,
+    "LlamaTokenizer": LlamaConverter,
 }
```
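To make the converter's choices concrete, here is a minimal sketch (illustrative only, not part of the diff) of the normalizer/decoder pair that `LlamaConverter` installs, assuming `tokenizers>=0.13.3` (the version this PR requires):

```python
from tokenizers import decoders, normalizers

# Normalization replaces the usual `add_prefix_space` pre-tokenizer
# (which is why `pre_tokenizer` returns None above): prepend the
# SentencePiece meta-symbol, then turn every space into it.
normalizer = normalizers.Sequence(
    [
        normalizers.Prepend(prepend="▁"),
        normalizers.Replace(pattern=" ", content="▁"),
    ]
)
print(normalizer.normalize_str("Hello world"))  # ▁Hello▁world

# Decoding inverts that pipeline: meta-symbols back to spaces, byte
# tokens (<0xE4>...) fused back into real characters, pieces joined,
# and the single prepended space stripped from the left.
decoder = decoders.Sequence(
    [
        decoders.Replace("▁", " "),
        decoders.ByteFallback(),
        decoders.Fuse(),
        decoders.Strip(content=" ", left=1),
    ]
)
print(decoder.decode(["▁Hello", "▁world"]))  # Hello world
```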
```diff
@@ -172,7 +172,13 @@ else:
     ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
     ("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
     ("lilt", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
-    ("llama", ("LlamaTokenizer" if is_sentencepiece_available() else None, None)),
+    (
+        "llama",
+        (
+            "LlamaTokenizer" if is_sentencepiece_available() else None,
+            "LlamaTokenizerFast" if is_tokenizers_available() else None,
+        ),
+    ),
     ("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
     (
         "longt5",
```
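With the mapping updated, `AutoTokenizer` can hand out the fast class when the `tokenizers` backend is installed. A hedged usage sketch, using the internal test checkpoint from this PR's tests:

```python
from transformers import AutoTokenizer

# The fast tokenizer is preferred by default when `tokenizers` is available.
tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
print(type(tok).__name__)  # LlamaTokenizerFast

# The slow, SentencePiece-backed tokenizer remains reachable.
slow = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer", use_fast=False)
print(type(slow).__name__)  # LlamaTokenizer
```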
```diff
@@ -17,6 +17,7 @@ from ...utils import (
     OptionalDependencyNotAvailable,
     _LazyModule,
     is_sentencepiece_available,
+    is_tokenizers_available,
     is_torch_available,
 )
@@ -33,6 +34,14 @@ except OptionalDependencyNotAvailable:
 else:
     _import_structure["tokenization_llama"] = ["LlamaTokenizer"]

+try:
+    if not is_tokenizers_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["tokenization_llama_fast"] = ["LlamaTokenizerFast"]
+
 try:
     if not is_torch_available():
         raise OptionalDependencyNotAvailable()
@@ -58,6 +67,14 @@ if TYPE_CHECKING:
     else:
         from .tokenization_llama import LlamaTokenizer

+    try:
+        if not is_tokenizers_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .tokenization_llama_fast import LlamaTokenizerFast
+
     try:
         if not is_torch_available():
             raise OptionalDependencyNotAvailable()
```
```python
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils.versions import require_version


require_version("tokenizers>=0.13.3")


class LlamaTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.

    This notably uses ByteFallback and no normalization.

    ```python
    >>> from transformers import LlamaTokenizerFast

    >>> tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
    >>> tokenizer.encode("Hello this is a test")
    [1, 15043, 445, 338, 263, 1243]
    ```

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        tokenizer_file (`str`):
            [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
            contains everything needed to load the tokenizer.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether to clean up spaces after decoding; cleanup consists of removing potential artifacts like extra
            spaces.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning-of-sequence token that was used during pretraining. Can be used as a sequence classifier
            token.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end-of-sequence token.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
    """

    padding_side = "left"

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        clean_up_tokenization_spaces=False,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        **kwargs,
    ):
        super().__init__(
            vocab_file=vocab_file,
            tokenizer_file=tokenizer_file,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            **kwargs,
        )
```
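A quick usage sketch for the new class; the IDs come from the docstring example above and depend on the checkpoint's vocabulary:

```python
from transformers import LlamaTokenizerFast

tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")

# The TemplateProcessing post-processor prepends BOS (<s>, id 1); no EOS is appended.
ids = tokenizer.encode("Hello this is a test")
print(ids)  # [1, 15043, 445, 338, 263, 1243]

print(tokenizer.decode(ids, skip_special_tokens=True))  # Hello this is a test
```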
```diff
@@ -219,6 +219,13 @@ class LEDTokenizerFast(metaclass=DummyObject):
         requires_backends(self, ["tokenizers"])


+class LlamaTokenizerFast(metaclass=DummyObject):
+    _backends = ["tokenizers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tokenizers"])
+
+
 class LongformerTokenizerFast(metaclass=DummyObject):
     _backends = ["tokenizers"]
```
```diff
+# coding=utf-8
 # Copyright 2023 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,8 +24,10 @@ from transformers import (
     SPIECE_UNDERLINE,
     AddedToken,
     LlamaTokenizer,
+    LlamaTokenizerFast,
     is_torch_available,
 )
+from transformers.convert_slow_tokenizer import convert_slow_tokenizer
 from transformers.testing_utils import (
     get_tests_dir,
     nested_simplify,
@@ -287,13 +290,11 @@ class LlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

 @require_sentencepiece
 @require_tokenizers
 class LlamaIntegrationTest(unittest.TestCase):
-    checkpoint_name = "hf-internal-testing/llama-tokenizer"
-
     @classmethod
     def setUpClass(cls):
-        cls.tokenizer: LlamaTokenizer = LlamaTokenizer.from_pretrained(cls.checkpoint_name)
-        cls.rust_tokenizer = cls.tokenizer  # TODO @narsil replace with the rust one
-        cls.pad_token_id = 1
+        checkpoint_name = "hf-internal-testing/llama-tokenizer"
+        cls.tokenizer: LlamaTokenizer = LlamaTokenizer.from_pretrained(checkpoint_name)
+        cls.rust_tokenizer = LlamaTokenizerFast.from_pretrained(checkpoint_name)
         return cls

     @require_torch
@@ -314,6 +315,27 @@ class LlamaIntegrationTest(unittest.TestCase):
         },
     )

+    @slow
+    def test_conversion(self):
+        # This is excruciatingly slow since it has to recreate the entire merge
+        # list from the original vocabulary in spm
+        self.rust_tokenizer.save_pretrained("./out")
+        with tempfile.TemporaryDirectory() as dirname:
+            self.rust_tokenizer.save_pretrained(dirname)
+            with open(os.path.join(dirname, "tokenizer.json"), "r") as f:
+                old_serialized = f.read()
+        new_tokenizer = convert_slow_tokenizer(self.tokenizer)
+        with tempfile.NamedTemporaryFile() as f:
+            new_tokenizer.save(f.name)
+            # Re-opening since `f` is in bytes.
+            new_serialized = open(f.name, "r").read()
+            with open("out_tokenizer.json", "w") as g:
+                g.write(new_serialized)
+
+            self.assertEqual(old_serialized, new_serialized)
+
     def test_simple_encode_decode(self):
         pyth_tokenizer = self.tokenizer
         rust_tokenizer = self.rust_tokenizer
@@ -362,11 +384,27 @@ class LlamaIntegrationTest(unittest.TestCase):
         self.assertEqual(pyth_tokenizer.encode(" Hello"), [1, 29871, 15043])
         self.assertEqual(rust_tokenizer.encode(" Hello"), [1, 29871, 15043])

+    def test_no_differences_showcase(self):
+        pyth_tokenizer = self.tokenizer
+        rust_tokenizer = self.rust_tokenizer
+        self.assertEqual(pyth_tokenizer.encode(""), [1])
+        self.assertEqual(rust_tokenizer.encode(""), [1])
+
+        self.assertEqual(pyth_tokenizer.encode(" "), [1, 259])
+        self.assertEqual(rust_tokenizer.encode(" "), [1, 259])
+
+        self.assertEqual(pyth_tokenizer.encode("  "), [1, 1678])
+        self.assertEqual(rust_tokenizer.encode("  "), [1, 1678])
+
+        self.assertEqual(pyth_tokenizer.encode(" Hello"), [1, 29871, 15043])
+        self.assertEqual(rust_tokenizer.encode(" Hello"), [1, 29871, 15043])
+
         self.assertEqual(pyth_tokenizer.encode("<s>"), [1, 1])
         self.assertEqual(rust_tokenizer.encode("<s>"), [1, 1])

-        self.assertEqual(pyth_tokenizer.encode(""), [1])
-        self.assertEqual(rust_tokenizer.encode(""), [1])
+    def test_no_differences_decode(self):
+        pyth_tokenizer = self.tokenizer
+        rust_tokenizer = self.rust_tokenizer

         self.assertEqual(pyth_tokenizer.decode([869]), ".")
         self.assertEqual(rust_tokenizer.decode([869]), ".")
@@ -374,6 +412,15 @@ class LlamaIntegrationTest(unittest.TestCase):
         self.assertEqual(pyth_tokenizer.decode([30112, 869]), "ا .")
         self.assertEqual(rust_tokenizer.decode([30112, 869]), "ا .")

+    def test_no_differences_special_tokens(self):
+        pyth_tokenizer = self.tokenizer
+        rust_tokenizer = self.rust_tokenizer
+        self.assertEqual(pyth_tokenizer.encode(""), [1])
+        self.assertEqual(rust_tokenizer.encode(""), [1])
+
+        self.assertEqual(pyth_tokenizer.encode("<s>"), [1, 1])
+        self.assertEqual(rust_tokenizer.encode("<s>"), [1, 1])
+
     @unittest.skipIf(
         os.getenv("RUN_TOKENIZER_INTEGRATION", "0") == "0",
         "RUN_TOKENIZER_INTEGRATION=1 to run tokenizer integration tests",
@@ -392,8 +439,8 @@ class LlamaIntegrationTest(unittest.TestCase):
         self.assertEqual(encoded1, encoded2)

-        decoded1 = pyth_tokenizer.decode(encoded1)
-        decoded2 = rust_tokenizer.decode(encoded2)
+        decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True)
+        decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True)

         self.assertEqual(decoded1, decoded2)
@@ -406,7 +453,7 @@ class LlamaIntegrationTest(unittest.TestCase):
         self.assertEqual(encoded1, encoded2)

-        decoded1 = pyth_tokenizer.decode(encoded1)
-        decoded2 = rust_tokenizer.decode(encoded2)
+        decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True)
+        decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True)

         self.assertEqual(decoded1, decoded2)
```
```diff
@@ -24,11 +24,10 @@ class ConvertSlowTokenizerTest(unittest.TestCase):
         original_tokenizer_with_bytefallback = FakeOriginalTokenizer(vocab_file=spm_model_file_with_bytefallback)

-        with warnings.catch_warnings(record=True) as w:
+        with self.assertRaises(RuntimeError) as cm:
             _ = SpmConverter(original_tokenizer_with_bytefallback)
-        self.assertEqual(len(w), 1)

         self.assertIn(
             "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
             " which is not implemented in the fast tokenizers.",
-            str(w[0].message),
+            str(cm.exception),
         )
```