Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b4e4eda9
Unverified
Commit
b4e4eda9
authored
Sep 20, 2024
by
Patrick von Platen
Committed by
GitHub
Sep 20, 2024
Browse files
[Bugfix][Core] Fix tekken edge case for mistral tokenizer (#8640)
parent
2874bac6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
4 deletions
+54
-4
tests/models/decoder_only/language/test_mistral.py
tests/models/decoder_only/language/test_mistral.py
+25
-1
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+29
-3
No files found.
tests/models/decoder_only/language/test_mistral.py
View file @
b4e4eda9
...
@@ -4,7 +4,7 @@ Run `pytest tests/models/test_mistral.py`.
...
@@ -4,7 +4,7 @@ Run `pytest tests/models/test_mistral.py`.
"""
"""
import
pytest
import
pytest
from
vllm
import
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
...utils
import
check_logprobs_close
from
...utils
import
check_logprobs_close
...
@@ -16,6 +16,10 @@ MODELS = [
...
@@ -16,6 +16,10 @@ MODELS = [
]
]
SAMPLING_PARAMS
=
SamplingParams
(
max_tokens
=
512
,
temperature
=
0.0
,
logprobs
=
5
)
SAMPLING_PARAMS
=
SamplingParams
(
max_tokens
=
512
,
temperature
=
0.0
,
logprobs
=
5
)
SYMBOLIC_LANG_PROMPTS
=
[
"勇敢な船乗りについての詩を書く"
,
# japanese
"寫一首關於勇敢的水手的詩"
,
# chinese
]
# for function calling
# for function calling
TOOLS
=
[{
TOOLS
=
[{
...
@@ -131,6 +135,26 @@ def test_mistral_format(
...
@@ -131,6 +135,26 @@ def test_mistral_format(
)
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
[
1
:])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"prompt"
,
SYMBOLIC_LANG_PROMPTS
)
def
test_mistral_symbolic_languages
(
model
:
str
,
dtype
:
str
,
prompt
:
str
,
)
->
None
:
prompt
=
"hi"
msg
=
{
"role"
:
"user"
,
"content"
:
prompt
}
llm
=
LLM
(
model
=
model
,
dtype
=
dtype
,
max_model_len
=
8192
,
tokenizer_mode
=
"mistral"
,
config_format
=
"mistral"
,
load_format
=
"mistral"
)
outputs
=
llm
.
chat
([
msg
],
sampling_params
=
SAMPLING_PARAMS
)
assert
"�"
not
in
outputs
[
0
].
outputs
[
0
].
text
.
strip
()
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
[
1
:])
# v1 can't do func calling
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
[
1
:])
# v1 can't do func calling
def
test_mistral_function_calling
(
def
test_mistral_function_calling
(
...
...
vllm/transformers_utils/tokenizers/mistral.py
View file @
b4e4eda9
...
@@ -175,10 +175,29 @@ class MistralTokenizer:
...
@@ -175,10 +175,29 @@ class MistralTokenizer:
def
convert_tokens_to_string
(
self
,
tokens
:
List
[
str
])
->
str
:
def
convert_tokens_to_string
(
self
,
tokens
:
List
[
str
])
->
str
:
if
isinstance
(
self
.
tokenizer
,
Tekkenizer
):
if
isinstance
(
self
.
tokenizer
,
Tekkenizer
):
return
""
.
join
(
t
for
t
in
tokens
tokens
=
[
if
t
not
in
self
.
tokenizer
.
_all_special_tokens
)
t
for
t
in
tokens
if
t
not
in
self
.
tokenizer
.
_all_special_tokens
]
if
any
(
isinstance
(
t
,
bytes
)
for
t
in
tokens
):
# we need to encode and decode all tokens again
shift
=
self
.
tokenizer
.
num_special_tokens
byte_tokens
=
[
t
.
encode
(
"utf-8"
)
if
not
isinstance
(
t
,
bytes
)
else
t
for
t
in
tokens
]
ids
=
[
self
.
tokenizer
.
_tekken_token2id_nospecial
[
t
]
+
shift
for
t
in
byte_tokens
]
decoded
=
self
.
tokenizer
.
decode
(
ids
)
else
:
decoded
=
""
.
join
(
tokens
)
else
:
else
:
return
self
.
tokenizer
.
decode
(
tokens
)
# type: ignore[arg-type]
decoded
=
self
.
tokenizer
.
decode
(
tokens
)
# type: ignore[arg-type]
return
decoded
def
decode
(
self
,
ids
:
Union
[
List
[
int
],
int
])
->
str
:
def
decode
(
self
,
ids
:
Union
[
List
[
int
],
int
])
->
str
:
if
isinstance
(
ids
,
int
):
if
isinstance
(
ids
,
int
):
...
@@ -200,4 +219,11 @@ class MistralTokenizer:
...
@@ -200,4 +219,11 @@ class MistralTokenizer:
self
.
tokenizer
)
self
.
tokenizer
)
tokens
=
[
self
.
tokenizer
.
id_to_piece
(
id
)
for
id
in
ids
]
tokens
=
[
self
.
tokenizer
.
id_to_piece
(
id
)
for
id
in
ids
]
if
any
(
t
.
strip
()
==
"�"
for
t
in
tokens
):
# if any stripped decoded token is undefined
# because it's invalid unicode then pass bytes
# See: https://github.com/vllm-project/vllm/pull/8640
tokens
=
[
self
.
tokenizer
.
id_to_byte_piece
(
id
)
for
id
in
ids
]
return
tokens
return
tokens
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment