Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a404e2c0
Unverified
Commit
a404e2c0
authored
Nov 06, 2025
by
Julien Denize
Committed by
GitHub
Nov 06, 2025
Browse files
Patch Mistral Tokenizer (#28146)
Signed-off-by:
Julien Denize
<
julien.denize@mistral.ai
>
parent
e31946f8
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
42 additions
and
22 deletions
+42
-22
tests/tokenization/test_mistral_tokenizer.py
tests/tokenization/test_mistral_tokenizer.py
+21
-8
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+21
-14
No files found.
tests/tokenization/test_mistral_tokenizer.py
View file @
a404e2c0
...
@@ -334,20 +334,20 @@ class TestMistralTokenizer:
...
@@ -334,20 +334,20 @@ class TestMistralTokenizer:
def
test_encode
(
self
,
mistral_tokenizer
:
MistralTokenizer
):
def
test_encode
(
self
,
mistral_tokenizer
:
MistralTokenizer
):
token_ids
=
(
token_ids
=
(
[
1
,
22177
,
4304
,
2662
,
2
]
[
1
,
22177
,
4304
,
2662
]
if
mistral_tokenizer
.
is_tekken
if
mistral_tokenizer
.
is_tekken
else
[
1
,
23325
,
2294
,
1686
,
2
]
else
[
1
,
23325
,
2294
,
1686
]
)
)
assert
mistral_tokenizer
.
encode
(
"Hello world !"
)
==
token_ids
[:
-
1
]
assert
mistral_tokenizer
.
encode
(
"Hello world !"
)
==
token_ids
assert
mistral_tokenizer
.
encode
(
"Hello world !"
,
max_length
=
3
)
==
token_ids
[:
-
2
]
assert
mistral_tokenizer
.
encode
(
"Hello world !"
,
max_length
=
3
)
==
token_ids
[:
-
1
]
assert
(
assert
(
mistral_tokenizer
.
encode
(
"Hello world !"
,
truncation
=
True
,
max_length
=
3
)
mistral_tokenizer
.
encode
(
"Hello world !"
,
truncation
=
True
,
max_length
=
3
)
==
token_ids
[:
-
2
]
==
token_ids
[:
-
1
]
)
)
assert
(
assert
(
mistral_tokenizer
.
encode
(
"Hello world !"
,
truncation
=
False
,
max_length
=
3
)
mistral_tokenizer
.
encode
(
"Hello world !"
,
truncation
=
False
,
max_length
=
3
)
==
token_ids
[:
-
1
]
==
token_ids
)
)
assert
(
assert
(
...
@@ -358,7 +358,7 @@ class TestMistralTokenizer:
...
@@ -358,7 +358,7 @@ class TestMistralTokenizer:
mistral_tokenizer
.
encode
(
mistral_tokenizer
.
encode
(
"Hello world !"
,
add_special_tokens
=
True
,
max_length
=
3
"Hello world !"
,
add_special_tokens
=
True
,
max_length
=
3
)
)
==
token_ids
[:
-
2
]
==
token_ids
[:
-
1
]
)
)
assert
(
assert
(
mistral_tokenizer
.
encode
(
mistral_tokenizer
.
encode
(
...
@@ -368,7 +368,7 @@ class TestMistralTokenizer:
...
@@ -368,7 +368,7 @@ class TestMistralTokenizer:
)
)
assert
(
assert
(
mistral_tokenizer
.
encode
(
"Hello world !"
,
add_special_tokens
=
False
)
mistral_tokenizer
.
encode
(
"Hello world !"
,
add_special_tokens
=
False
)
==
token_ids
[
1
:
-
1
]
==
token_ids
[
1
:]
)
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -1088,6 +1088,19 @@ class TestMistralTokenizer:
...
@@ -1088,6 +1088,19 @@ class TestMistralTokenizer:
==
expected_tokens
[
mistral_tokenizer
.
is_tekken
]
==
expected_tokens
[
mistral_tokenizer
.
is_tekken
]
)
)
def
test_decode_int
(
self
,
mistral_tokenizer
:
MistralTokenizer
,
):
ids
=
1
assert
(
mistral_tokenizer
.
decode
(
ids
,
skip_special_tokens
=
False
,
)
==
"<s>"
)
def
test_convert_tokens_to_string
(
self
,
mistral_tokenizer
:
MistralTokenizer
):
def
test_convert_tokens_to_string
(
self
,
mistral_tokenizer
:
MistralTokenizer
):
tokens
=
(
tokens
=
(
[
[
...
...
vllm/transformers_utils/tokenizers/mistral.py
View file @
a404e2c0
...
@@ -165,6 +165,7 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
...
@@ -165,6 +165,7 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
class
MistralTokenizer
(
TokenizerBase
):
class
MistralTokenizer
(
TokenizerBase
):
def
__init__
(
self
,
tokenizer
:
"TransformersMistralTokenizer"
)
->
None
:
def
__init__
(
self
,
tokenizer
:
"TransformersMistralTokenizer"
)
->
None
:
from
mistral_common.protocol.instruct.validator
import
ValidationMode
from
mistral_common.tokens.tokenizers.sentencepiece
import
(
from
mistral_common.tokens.tokenizers.sentencepiece
import
(
SentencePieceTokenizer
,
SentencePieceTokenizer
,
)
)
...
@@ -175,6 +176,14 @@ class MistralTokenizer(TokenizerBase):
...
@@ -175,6 +176,14 @@ class MistralTokenizer(TokenizerBase):
self
.
instruct
=
self
.
mistral
.
instruct_tokenizer
self
.
instruct
=
self
.
mistral
.
instruct_tokenizer
self
.
tokenizer
=
self
.
instruct
.
tokenizer
self
.
tokenizer
=
self
.
instruct
.
tokenizer
mode
=
self
.
mistral
.
_chat_completion_request_validator
.
_mode
if
mode
!=
ValidationMode
.
test
:
raise
ValueError
(
"Mistral tokenizer must be in test mode. Make sure to "
"set `mode='ValidationMode.test'` when creating the "
"Mistral tokenizer."
)
_mistral_version_str
=
str
(
self
.
tokenizer
.
version
.
value
)
_mistral_version_str
=
str
(
self
.
tokenizer
.
version
.
value
)
self
.
version
:
int
=
int
(
_mistral_version_str
.
split
(
"v"
)[
-
1
])
self
.
version
:
int
=
int
(
_mistral_version_str
.
split
(
"v"
)[
-
1
])
...
@@ -205,6 +214,7 @@ class MistralTokenizer(TokenizerBase):
...
@@ -205,6 +214,7 @@ class MistralTokenizer(TokenizerBase):
def
from_pretrained
(
def
from_pretrained
(
cls
,
path_or_repo_id
:
str
,
*
,
revision
:
str
|
None
=
None
cls
,
path_or_repo_id
:
str
,
*
,
revision
:
str
|
None
=
None
)
->
"MistralTokenizer"
:
)
->
"MistralTokenizer"
:
from
mistral_common.protocol.instruct.validator
import
ValidationMode
from
transformers.tokenization_mistral_common
import
(
from
transformers.tokenization_mistral_common
import
(
MistralCommonTokenizer
as
TransformersMistralTokenizer
,
MistralCommonTokenizer
as
TransformersMistralTokenizer
,
)
)
...
@@ -212,7 +222,7 @@ class MistralTokenizer(TokenizerBase):
...
@@ -212,7 +222,7 @@ class MistralTokenizer(TokenizerBase):
str_revision
=
"main"
if
revision
is
None
else
revision
str_revision
=
"main"
if
revision
is
None
else
revision
return
cls
(
return
cls
(
TransformersMistralTokenizer
.
from_pretrained
(
TransformersMistralTokenizer
.
from_pretrained
(
path_or_repo_id
,
revision
=
str_revision
path_or_repo_id
,
revision
=
str_revision
,
mode
=
ValidationMode
.
test
)
)
)
)
...
@@ -339,15 +349,9 @@ class MistralTokenizer(TokenizerBase):
...
@@ -339,15 +349,9 @@ class MistralTokenizer(TokenizerBase):
max_length
:
int
|
None
=
None
,
max_length
:
int
|
None
=
None
,
add_special_tokens
:
bool
|
None
=
None
,
add_special_tokens
:
bool
|
None
=
None
,
)
->
list
[
int
]:
)
->
list
[
int
]:
if
add_special_tokens
is
not
None
:
encoded
=
self
.
tokenizer
.
encode
(
return
self
.
transformers_tokenizer
.
encode
(
text
,
bos
=
add_special_tokens
is
not
False
,
eos
=
False
text
,
truncation
=
truncation
,
max_length
=
max_length
,
add_special_tokens
=
add_special_tokens
,
)
)
else
:
encoded
=
self
.
tokenizer
.
encode
(
text
,
bos
=
True
,
eos
=
False
)
if
truncation
is
not
False
and
max_length
is
not
None
:
if
truncation
is
not
False
and
max_length
is
not
None
:
return
encoded
[:
max_length
]
return
encoded
[:
max_length
]
...
@@ -383,6 +387,9 @@ class MistralTokenizer(TokenizerBase):
...
@@ -383,6 +387,9 @@ class MistralTokenizer(TokenizerBase):
)
)
def
decode
(
self
,
ids
:
list
[
int
]
|
int
,
skip_special_tokens
:
bool
=
True
)
->
str
:
def
decode
(
self
,
ids
:
list
[
int
]
|
int
,
skip_special_tokens
:
bool
=
True
)
->
str
:
if
isinstance
(
ids
,
int
):
ids
=
[
ids
]
return
self
.
transformers_tokenizer
.
decode
(
return
self
.
transformers_tokenizer
.
decode
(
ids
,
skip_special_tokens
=
skip_special_tokens
ids
,
skip_special_tokens
=
skip_special_tokens
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment