"...kernels/git@developer.sourcefind.cn:change/sglang.git" did not exist on "515ef4facbc89cd7c093c198386a8817fce856d6"
Unverified commit 117dba99, authored by Théo Matussière, committed by GitHub
Browse files

fix backend tokenizer args override: key mismatch (#10686)



* fix backend tokenizer args override: key mismatch

* no touching the docs

* fix mpnet

* add mpnet to test

* fix test
Co-authored-by: default avatartheo <theo@matussie.re>
parent 427ea3fe
...@@ -190,11 +190,11 @@ class BertTokenizerFast(PreTrainedTokenizerFast): ...@@ -190,11 +190,11 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
if ( if (
pre_tok_state.get("do_lower_case", do_lower_case) != do_lower_case pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
or pre_tok_state.get("strip_accents", strip_accents) != strip_accents or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
): ):
pre_tok_class = getattr(normalizers, pre_tok_state.pop("type")) pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
pre_tok_state["do_lower_case"] = do_lower_case pre_tok_state["lowercase"] = do_lower_case
pre_tok_state["strip_accents"] = strip_accents pre_tok_state["strip_accents"] = strip_accents
self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state) self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
......
...@@ -138,11 +138,11 @@ class MPNetTokenizerFast(PreTrainedTokenizerFast): ...@@ -138,11 +138,11 @@ class MPNetTokenizerFast(PreTrainedTokenizerFast):
pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
if ( if (
pre_tok_state.get("do_lower_case", do_lower_case) != do_lower_case pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
or pre_tok_state.get("strip_accents", strip_accents) != strip_accents or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
): ):
pre_tok_class = getattr(normalizers, pre_tok_state.pop("type")) pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
pre_tok_state["do_lower_case"] = do_lower_case pre_tok_state["lowercase"] = do_lower_case
pre_tok_state["strip_accents"] = strip_accents pre_tok_state["strip_accents"] = strip_accents
self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state) self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
......
...@@ -110,3 +110,14 @@ class AutoTokenizerTest(unittest.TestCase): ...@@ -110,3 +110,14 @@ class AutoTokenizerTest(unittest.TestCase):
def test_from_pretrained_use_fast_toggle(self): def test_from_pretrained_use_fast_toggle(self):
self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased", use_fast=False), BertTokenizer) self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased", use_fast=False), BertTokenizer)
self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased"), BertTokenizerFast) self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased"), BertTokenizerFast)
@require_tokenizers
def test_do_lower_case(self):
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", do_lower_case=False)
sample = "Hello, world. How are you?"
tokens = tokenizer.tokenize(sample)
self.assertEqual("[UNK]", tokens[0])
tokenizer = AutoTokenizer.from_pretrained("microsoft/mpnet-base", do_lower_case=False)
tokens = tokenizer.tokenize(sample)
self.assertEqual("[UNK]", tokens[0])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment