Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3ee696a6
Unverified
Commit
3ee696a6
authored
Feb 11, 2025
by
Keyun Tong
Committed by
GitHub
Feb 12, 2025
Browse files
[RFC][vllm-API] Support tokenizer registry for customized tokenizer in vLLM (#12518)
Signed-off-by:
Keyun Tong
<
tongkeyun@gmail.com
>
parent
72c2b68d
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
343 additions
and
41 deletions
+343
-41
benchmarks/benchmark_serving.py
benchmarks/benchmark_serving.py
+3
-2
tests/tokenization/test_tokenizer_registry.py
tests/tokenization/test_tokenizer_registry.py
+123
-0
vllm/config.py
vllm/config.py
+5
-4
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+4
-2
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+19
-12
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+1
-2
vllm/entrypoints/openai/serving_score.py
vllm/entrypoints/openai/serving_score.py
+1
-1
vllm/logits_process.py
vllm/logits_process.py
+1
-1
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+12
-6
vllm/transformers_utils/tokenizer_base.py
vllm/transformers_utils/tokenizer_base.py
+146
-0
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+28
-11
No files found.
benchmarks/benchmark_serving.py
View file @
3ee696a6
...
...
@@ -1275,11 +1275,12 @@ if __name__ == "__main__":
'--tokenizer-mode'
,
type
=
str
,
default
=
"auto"
,
choices
=
[
'auto'
,
'slow'
,
'mistral'
],
choices
=
[
'auto'
,
'slow'
,
'mistral'
,
'custom'
],
help
=
'The tokenizer mode.
\n\n
* "auto" will use the '
'fast tokenizer if available.
\n
* "slow" will '
'always use the slow tokenizer.
\n
* '
'"mistral" will always use the `mistral_common` tokenizer.'
)
'"mistral" will always use the `mistral_common` tokenizer.
\n
*'
'"custom" will use --tokenizer to select the preregistered tokenizer.'
)
parser
.
add_argument
(
"--served-model-name"
,
type
=
str
,
...
...
tests/tokenization/test_tokenizer_registry.py
0 → 100644
View file @
3ee696a6
# SPDX-License-Identifier: Apache-2.0
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Union
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.transformers_utils.tokenizer_base
import
(
TokenizerBase
,
TokenizerRegistry
)
if
TYPE_CHECKING
:
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
class TestTokenizer(TokenizerBase):
    """Stub ``TokenizerBase`` subclass used by the tokenizer registry test.

    Only ``from_pretrained``, ``bos_token_id`` and ``eos_token_id`` produce
    values; every other member raises ``NotImplementedError``.
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
        """Return a fresh ``TestTokenizer``; all arguments are ignored."""
        return TestTokenizer()

    # --- properties -------------------------------------------------------

    @property
    def all_special_tokens_extended(self) -> List[str]:
        raise NotImplementedError

    @property
    def all_special_tokens(self) -> List[str]:
        raise NotImplementedError

    @property
    def all_special_ids(self) -> List[int]:
        raise NotImplementedError

    @property
    def bos_token_id(self) -> int:
        # Fixed value checked by test_customized_tokenizer.
        return 0

    @property
    def eos_token_id(self) -> int:
        # Fixed value checked by test_customized_tokenizer.
        return 1

    @property
    def sep_token(self) -> str:
        raise NotImplementedError

    @property
    def pad_token(self) -> str:
        raise NotImplementedError

    @property
    def is_fast(self) -> bool:
        raise NotImplementedError

    @property
    def vocab_size(self) -> int:
        raise NotImplementedError

    @property
    def max_token_id(self) -> int:
        raise NotImplementedError

    # --- methods (all unimplemented in this stub) -------------------------

    def __call__(
        self,
        text: Union[str, List[str], List[int]],
        text_pair: Optional[str] = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
        raise NotImplementedError

    def get_vocab(self) -> Dict[str, int]:
        raise NotImplementedError

    def get_added_vocab(self) -> Dict[str, int]:
        raise NotImplementedError

    def encode_one(
        self,
        text: str,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ) -> List[int]:
        raise NotImplementedError

    def encode(self,
               text: str,
               add_special_tokens: Optional[bool] = None) -> List[int]:
        raise NotImplementedError

    def apply_chat_template(self,
                            messages: List["ChatCompletionMessageParam"],
                            tools: Optional[List[Dict[str, Any]]] = None,
                            **kwargs) -> List[int]:
        raise NotImplementedError

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        raise NotImplementedError

    def decode(self,
               ids: Union[List[int], int],
               skip_special_tokens: bool = True) -> str:
        raise NotImplementedError

    def convert_ids_to_tokens(
        self,
        ids: List[int],
        skip_special_tokens: bool = True,
    ) -> List[str]:
        raise NotImplementedError
def test_customized_tokenizer():
    """Check that a registered custom tokenizer is returned by both lookups.

    Registers ``TestTokenizer`` under the name ``"test_tokenizer"``, then
    fetches it first through ``TokenizerRegistry.get_tokenizer`` and then
    through ``get_tokenizer(..., tokenizer_mode="custom")``, verifying the
    instance type and its BOS/EOS ids each time.
    """
    TokenizerRegistry.register("test_tokenizer",
                               "tests.tokenization.test_tokenizer_registry",
                               "TestTokenizer")

    # Lookups are wrapped in callables so each one runs only after the
    # previous result has been fully checked (same order as a linear test).
    lookups = (
        lambda: TokenizerRegistry.get_tokenizer("test_tokenizer"),
        lambda: get_tokenizer("test_tokenizer", tokenizer_mode="custom"),
    )
    for fetch in lookups:
        tokenizer = fetch()
        assert isinstance(tokenizer, TestTokenizer)
        assert tokenizer.bos_token_id == 0
        assert tokenizer.eos_token_id == 1
vllm/config.py
View file @
3ee696a6
...
...
@@ -102,8 +102,9 @@ class ModelConfig:
it; otherwise, you must specify explicitly which task to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, "slow" will always use the slow tokenizer, and
"mistral" will always use the tokenizer from `mistral_common`.
available, "slow" will always use the slow tokenizer,
"mistral" will always use the tokenizer from `mistral_common`, and
"custom" will use --tokenizer to select the preregistered tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
allowed_local_media_path: Allowing API requests to read local images or
...
...
@@ -467,10 +468,10 @@ class ModelConfig:
def
_verify_tokenizer_mode
(
self
)
->
None
:
tokenizer_mode
=
self
.
tokenizer_mode
.
lower
()
if
tokenizer_mode
not
in
[
"auto"
,
"slow"
,
"mistral"
]:
if
tokenizer_mode
not
in
[
"auto"
,
"slow"
,
"mistral"
,
"custom"
]:
raise
ValueError
(
f
"Unknown tokenizer mode:
{
self
.
tokenizer_mode
}
. Must be "
"either 'auto', 'slow'
or
'mistral'."
)
"either 'auto', 'slow'
,
'mistral'
or 'custom'
."
)
self
.
tokenizer_mode
=
tokenizer_mode
def
_get_preferred_task
(
...
...
vllm/engine/arg_utils.py
View file @
3ee696a6
...
...
@@ -284,11 +284,13 @@ class EngineArgs:
'--tokenizer-mode'
,
type
=
str
,
default
=
EngineArgs
.
tokenizer_mode
,
choices
=
[
'auto'
,
'slow'
,
'mistral'
],
choices
=
[
'auto'
,
'slow'
,
'mistral'
,
'custom'
],
help
=
'The tokenizer mode.
\n\n
* "auto" will use the '
'fast tokenizer if available.
\n
* "slow" will '
'always use the slow tokenizer.
\n
* '
'"mistral" will always use the `mistral_common` tokenizer.'
)
'"mistral" will always use the `mistral_common` tokenizer.
\n
* '
'"custom" will use --tokenizer to select the '
'preregistered tokenizer.'
)
parser
.
add_argument
(
'--trust-remote-code'
,
action
=
'store_true'
,
help
=
'Trust remote code from huggingface.'
)
...
...
vllm/entrypoints/llm.py
View file @
3ee696a6
...
...
@@ -1051,9 +1051,9 @@ class LLM:
def
_cross_encoding_score
(
self
,
tokenizer
:
Union
[
AnyTokenizer
]
,
text_1
:
List
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]
],
text_2
:
List
[
Union
[
str
,
TextPrompt
,
TokensPrompt
]
],
tokenizer
:
AnyTokenizer
,
text_1
:
List
[
str
],
text_2
:
List
[
str
],
truncate_prompt_tokens
:
Optional
[
int
]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
...
...
@@ -1176,29 +1176,36 @@ class LLM:
if
isinstance
(
text_1
,
(
str
,
dict
)):
# Convert a single prompt to a list.
text_1
=
[
text_1
]
text_1
=
[
ensure_str
(
t
)
for
t
in
text_1
]
input_text_1
:
List
[
str
]
=
[
ensure_str
(
t
)
for
t
in
text_1
]
if
isinstance
(
text_2
,
(
str
,
dict
)):
# Convert a single prompt to a list.
text_2
=
[
text_2
]
text_2
=
[
ensure_str
(
t
)
for
t
in
text_2
]
input_text_2
:
List
[
str
]
=
[
ensure_str
(
t
)
for
t
in
text_2
]
if
len
(
text_1
)
>
1
and
len
(
text_1
)
!=
len
(
text_2
):
if
len
(
input_
text_1
)
>
1
and
len
(
input_
text_1
)
!=
len
(
input_
text_2
):
raise
ValueError
(
"Input lengths must be either 1:1, 1:N or N:N"
)
if
len
(
text_1
)
==
0
:
if
len
(
input_
text_1
)
==
0
:
raise
ValueError
(
"At least one text element must be given"
)
if
len
(
text_2
)
==
0
:
if
len
(
input_
text_2
)
==
0
:
raise
ValueError
(
"At least one text_pair element must be given"
)
if
self
.
llm_engine
.
model_config
.
is_cross_encoder
:
return
self
.
_cross_encoding_score
(
tokenizer
,
text_1
,
text_2
,
return
self
.
_cross_encoding_score
(
tokenizer
,
input_text_1
,
input_text_2
,
truncate_prompt_tokens
,
use_tqdm
,
lora_request
,
prompt_adapter_request
)
else
:
return
self
.
_embedding_score
(
tokenizer
,
text_1
,
text_2
,
truncate_prompt_tokens
,
use_tqdm
,
lora_request
,
prompt_adapter_request
)
return
self
.
_embedding_score
(
tokenizer
,
input_text_1
,
# type: ignore[arg-type]
input_text_2
,
# type: ignore[arg-type]
truncate_prompt_tokens
,
use_tqdm
,
lora_request
,
prompt_adapter_request
)
def
start_profile
(
self
)
->
None
:
self
.
llm_engine
.
start_profile
()
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
3ee696a6
...
...
@@ -400,8 +400,7 @@ class OpenAIServing:
_chat_template_kwargs
.
update
(
chat_template_kwargs
or
{})
request_prompt
:
Union
[
str
,
List
[
int
]]
is_mistral_tokenizer
=
isinstance
(
tokenizer
,
MistralTokenizer
)
if
is_mistral_tokenizer
:
if
isinstance
(
tokenizer
,
MistralTokenizer
):
request_prompt
=
apply_mistral_chat_template
(
tokenizer
,
messages
=
messages
,
...
...
vllm/entrypoints/openai/serving_score.py
View file @
3ee696a6
...
...
@@ -121,7 +121,7 @@ class OpenAIServingScores(OpenAIServing):
tokenize_async
=
make_async
(
tokenizer
.
__call__
,
executor
=
self
.
_tokenizer_executor
)
prompt_inputs
=
await
tokenize_async
(
text
=
q
,
prompt_inputs
=
await
tokenize_async
(
q
,
text_pair
=
t
,
**
tokenization_kwargs
)
...
...
vllm/logits_process.py
View file @
3ee696a6
...
...
@@ -31,7 +31,7 @@ def get_bad_words_logits_processors(
if
isinstance
(
tokenizer
,
MistralTokenizer
):
# Mistral tokenizers should not add special tokens
prompt_token_ids
=
tokenizer
.
encode
(
promp
t
=
prompt
)
prompt_token_ids
=
tokenizer
.
encode
(
tex
t
=
prompt
)
else
:
prompt_token_ids
=
tokenizer
.
encode
(
text
=
prompt
,
add_special_tokens
=
False
)
...
...
vllm/transformers_utils/tokenizer.py
View file @
3ee696a6
...
...
@@ -14,6 +14,8 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.transformers_utils.tokenizer_base
import
(
TokenizerBase
,
TokenizerRegistry
)
from
vllm.transformers_utils.tokenizers
import
MistralTokenizer
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.utils
import
make_async
...
...
@@ -21,7 +23,7 @@ from vllm.utils import make_async
logger
=
init_logger
(
__name__
)
AnyTokenizer
=
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
,
Mistral
Tokenizer
]
Tokenizer
Base
]
def
decode_tokens
(
...
...
@@ -47,11 +49,7 @@ def encode_tokens(
Backend-agnostic equivalent of HF's
:code:`tokenizer.encode(text, add_special_tokens=...)`.
"""
if
isinstance
(
tokenizer
,
MistralTokenizer
):
return
tokenizer
.
tokenizer
.
encode
(
text
,
bos
=
add_special_tokens
,
eos
=
add_special_tokens
)
elif
add_special_tokens
is
not
None
:
if
add_special_tokens
is
not
None
:
return
tokenizer
.
encode
(
text
,
add_special_tokens
=
add_special_tokens
)
return
tokenizer
.
encode
(
text
)
...
...
@@ -183,9 +181,17 @@ def get_tokenizer(
'encoding and decoding.'
,
FutureWarning
,
stacklevel
=
2
)
tokenizer
:
AnyTokenizer
if
tokenizer_mode
==
"mistral"
:
tokenizer
=
MistralTokenizer
.
from_pretrained
(
str
(
tokenizer_name
),
revision
=
revision
)
elif
tokenizer_mode
==
"custom"
:
tokenizer
=
TokenizerRegistry
.
get_tokenizer
(
str
(
tokenizer_name
),
*
args
,
revision
=
revision
,
download_dir
=
download_dir
,
**
kwargs
)
else
:
try
:
tokenizer
=
AutoTokenizer
.
from_pretrained
(
...
...
vllm/transformers_utils/tokenizer_base.py
0 → 100644
View file @
3ee696a6
# SPDX-License-Identifier: Apache-2.0
import
importlib
from
abc
import
ABC
,
abstractmethod
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
if
TYPE_CHECKING
:
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
class TokenizerBase(ABC):
    """Abstract interface that tokenizer classes implement.

    Every member below except ``__len__`` is abstract: a concrete subclass
    must override all of them before it can be instantiated. The base
    bodies raise ``NotImplementedError`` if reached (e.g. via ``super()``).
    """

    # --- abstract read-only properties ------------------------------------

    @property
    @abstractmethod
    def all_special_tokens_extended(self) -> List[str]:
        raise NotImplementedError

    @property
    @abstractmethod
    def all_special_tokens(self) -> List[str]:
        raise NotImplementedError

    @property
    @abstractmethod
    def all_special_ids(self) -> List[int]:
        raise NotImplementedError

    @property
    @abstractmethod
    def bos_token_id(self) -> int:
        raise NotImplementedError

    @property
    @abstractmethod
    def eos_token_id(self) -> int:
        raise NotImplementedError

    @property
    @abstractmethod
    def sep_token(self) -> str:
        raise NotImplementedError

    @property
    @abstractmethod
    def pad_token(self) -> str:
        raise NotImplementedError

    @property
    @abstractmethod
    def is_fast(self) -> bool:
        raise NotImplementedError

    @property
    @abstractmethod
    def vocab_size(self) -> int:
        raise NotImplementedError

    @property
    @abstractmethod
    def max_token_id(self) -> int:
        raise NotImplementedError

    # --- concrete helper --------------------------------------------------

    def __len__(self) -> int:
        """Return ``self.vocab_size`` so ``len(tokenizer)`` works."""
        return self.vocab_size

    # --- abstract methods -------------------------------------------------

    @abstractmethod
    def __call__(
        self,
        text: Union[str, List[str], List[int]],
        text_pair: Optional[str] = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
        raise NotImplementedError

    @abstractmethod
    def get_vocab(self) -> Dict[str, int]:
        raise NotImplementedError

    @abstractmethod
    def get_added_vocab(self) -> Dict[str, int]:
        raise NotImplementedError

    @abstractmethod
    def encode_one(
        self,
        text: str,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ) -> List[int]:
        raise NotImplementedError

    @abstractmethod
    def encode(self,
               text: str,
               add_special_tokens: Optional[bool] = None) -> List[int]:
        raise NotImplementedError

    @abstractmethod
    def apply_chat_template(self,
                            messages: List["ChatCompletionMessageParam"],
                            tools: Optional[List[Dict[str, Any]]] = None,
                            **kwargs) -> List[int]:
        raise NotImplementedError

    @abstractmethod
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        raise NotImplementedError

    @abstractmethod
    def decode(self,
               ids: Union[List[int], int],
               skip_special_tokens: bool = True) -> str:
        raise NotImplementedError

    @abstractmethod
    def convert_ids_to_tokens(
        self,
        ids: List[int],
        skip_special_tokens: bool = True,
    ) -> List[str]:
        raise NotImplementedError
class TokenizerRegistry:
    """Name-based registry of tokenizer classes, resolved lazily by import.

    Entries store only the module path and class name as strings; the
    module is imported when ``get_tokenizer`` is called, not at
    registration time.
    """

    # Tokenizer name -> (tokenizer module, tokenizer class)
    REGISTRY: Dict[str, Tuple[str, str]] = {}

    @staticmethod
    def register(name: str, module: str, class_name: str) -> None:
        """Record ``(module, class_name)`` under ``name``.

        An existing entry with the same name is overwritten.
        """
        TokenizerRegistry.REGISTRY[name] = (module, class_name)

    @staticmethod
    def get_tokenizer(
        tokenizer_name: str,
        *args,
        **kwargs,
    ) -> TokenizerBase:
        """Build the tokenizer registered under ``tokenizer_name``.

        Imports the registered module, looks up the registered class on it
        and returns ``class.from_pretrained(*args, **kwargs)``. Note that
        ``tokenizer_name`` itself is not forwarded to ``from_pretrained``.

        Raises:
            ValueError: if ``tokenizer_name`` has not been registered.
        """
        entry = TokenizerRegistry.REGISTRY.get(tokenizer_name)
        if entry is None:
            raise ValueError(f"Tokenizer {tokenizer_name} not found.")

        module_path, cls_name = entry
        loaded_module = importlib.import_module(module_path)
        tokenizer_type = getattr(loaded_module, cls_name)
        return tokenizer_type.from_pretrained(*args, **kwargs)
vllm/transformers_utils/tokenizers/mistral.py
View file @
3ee696a6
...
...
@@ -10,6 +10,7 @@ import huggingface_hub
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer_base
import
TokenizerBase
from
vllm.utils
import
is_list_of
if
TYPE_CHECKING
:
...
...
@@ -140,7 +141,7 @@ def make_mistral_chat_completion_request(
tools
=
tools
)
# type: ignore[type-var]
class
MistralTokenizer
:
class
MistralTokenizer
(
TokenizerBase
)
:
def
__init__
(
self
,
tokenizer
:
"PublicMistralTokenizer"
)
->
None
:
self
.
mistral
=
tokenizer
...
...
@@ -251,6 +252,14 @@ class MistralTokenizer:
def
eos_token_id
(
self
)
->
int
:
return
self
.
tokenizer
.
eos_id
@
property
def
sep_token
(
self
)
->
str
:
raise
NotImplementedError
()
@
property
def
pad_token
(
self
)
->
str
:
raise
NotImplementedError
()
@
property
def
is_fast
(
self
)
->
bool
:
return
True
...
...
@@ -268,25 +277,26 @@ class MistralTokenizer:
def
__call__
(
self
,
prompt
:
Union
[
str
,
List
[
str
],
List
[
int
]],
text
:
Union
[
str
,
List
[
str
],
List
[
int
]],
text_pair
:
Optional
[
str
]
=
None
,
add_special_tokens
:
bool
=
False
,
truncation
:
bool
=
False
,
max_length
:
Optional
[
int
]
=
None
,
):
input_ids
:
Union
[
List
[
int
],
List
[
List
[
int
]]]
# For List[str], original prompt text
if
is_list_of
(
promp
t
,
str
):
if
is_list_of
(
tex
t
,
str
):
input_ids_
:
List
[
List
[
int
]]
=
[]
for
p
in
promp
t
:
for
p
in
tex
t
:
each_input_ids
=
self
.
encode_one
(
p
,
truncation
,
max_length
)
input_ids_
.
append
(
each_input_ids
)
input_ids
=
input_ids_
# For List[int], apply chat template output, already tokens.
elif
is_list_of
(
promp
t
,
int
):
input_ids
=
promp
t
elif
is_list_of
(
tex
t
,
int
):
input_ids
=
tex
t
# For str, single prompt text
else
:
input_ids
=
self
.
encode_one
(
promp
t
,
truncation
,
max_length
)
input_ids
=
self
.
encode_one
(
tex
t
,
truncation
,
max_length
)
return
Encoding
(
input_ids
=
input_ids
)
def
get_vocab
(
self
)
->
Dict
[
str
,
int
]:
...
...
@@ -300,22 +310,29 @@ class MistralTokenizer:
def
encode_one
(
self
,
promp
t
:
str
,
tex
t
:
str
,
truncation
:
bool
=
False
,
max_length
:
Optional
[
int
]
=
None
,
)
->
List
[
int
]:
# Mistral Tokenizers should not add special tokens
input_ids
=
self
.
encode
(
promp
t
)
input_ids
=
self
.
encode
(
tex
t
)
if
truncation
:
input_ids
=
input_ids
[:
max_length
]
return
input_ids
def
encode
(
self
,
prompt
:
str
)
->
List
[
int
]:
def
encode
(
self
,
text
:
str
,
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
List
[
int
]:
# `encode` should only be used for prompt completion
# it should never be used for chat_completion.
# For chat completion use `apply_chat_template`
return
self
.
tokenizer
.
encode
(
prompt
,
bos
=
True
,
eos
=
False
)
if
add_special_tokens
is
not
None
:
return
self
.
tokenizer
.
encode
(
text
,
bos
=
add_special_tokens
,
eos
=
add_special_tokens
)
else
:
return
self
.
tokenizer
.
encode
(
text
,
bos
=
True
,
eos
=
False
)
def
apply_chat_template
(
self
,
messages
:
List
[
"ChatCompletionMessageParam"
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment