Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8fb5dea5
Commit
8fb5dea5
authored
May 20, 2025
by
zhuwenwen
Browse files
support qiyuan-8b-v2 and FM9GForCausalLM
parent
a5aa55e8
Changes
15
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
1352 additions
and
24 deletions
+1352
-24
vllm/config.py
vllm/config.py
+2
-2
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-1
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+20
-3
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+9
-5
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+2
-0
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+10
-2
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+6
-3
vllm/model_executor/models/fm9g.py
vllm/model_executor/models/fm9g.py
+592
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+2
-0
vllm/transformers_utils/configs/fm9g.py
vllm/transformers_utils/configs/fm9g.py
+187
-0
vllm/transformers_utils/detokenizer.py
vllm/transformers_utils/detokenizer.py
+17
-4
vllm/transformers_utils/detokenizer_utils.py
vllm/transformers_utils/detokenizer_utils.py
+17
-3
vllm/transformers_utils/tokenizers/__init__.py
vllm/transformers_utils/tokenizers/__init__.py
+3
-1
vllm/transformers_utils/tokenizers/cpm_9g.py
vllm/transformers_utils/tokenizers/cpm_9g.py
+483
-0
No files found.
vllm/config.py
View file @
8fb5dea5
...
@@ -640,10 +640,10 @@ class ModelConfig:
...
@@ -640,10 +640,10 @@ class ModelConfig:
def
_verify_tokenizer_mode
(
self
)
->
None
:
def
_verify_tokenizer_mode
(
self
)
->
None
:
tokenizer_mode
=
self
.
tokenizer_mode
.
lower
()
tokenizer_mode
=
self
.
tokenizer_mode
.
lower
()
if
tokenizer_mode
not
in
[
"auto"
,
"slow"
,
"mistral"
,
"custom"
]:
if
tokenizer_mode
not
in
[
"auto"
,
"cpm"
,
"slow"
,
"mistral"
,
"custom"
]:
raise
ValueError
(
raise
ValueError
(
f
"Unknown tokenizer mode:
{
self
.
tokenizer_mode
}
. Must be "
f
"Unknown tokenizer mode:
{
self
.
tokenizer_mode
}
. Must be "
"either 'auto', 'slow', 'mistral' or 'custom'."
)
"either 'auto',
'cpm',
'slow', 'mistral' or 'custom'."
)
self
.
tokenizer_mode
=
tokenizer_mode
self
.
tokenizer_mode
=
tokenizer_mode
def
_get_preferred_task
(
def
_get_preferred_task
(
...
...
vllm/engine/arg_utils.py
View file @
8fb5dea5
...
@@ -421,7 +421,7 @@ class EngineArgs:
...
@@ -421,7 +421,7 @@ class EngineArgs:
'--tokenizer-mode'
,
'--tokenizer-mode'
,
type
=
str
,
type
=
str
,
default
=
EngineArgs
.
tokenizer_mode
,
default
=
EngineArgs
.
tokenizer_mode
,
choices
=
[
'auto'
,
'slow'
,
'mistral'
,
'custom'
],
choices
=
[
'auto'
,
'cpm'
,
'slow'
,
'mistral'
,
'custom'
],
help
=
'The tokenizer mode.
\n\n
* "auto" will use the '
help
=
'The tokenizer mode.
\n\n
* "auto" will use the '
'fast tokenizer if available.
\n
* "slow" will '
'fast tokenizer if available.
\n
* "slow" will '
'always use the slow tokenizer.
\n
* '
'always use the slow tokenizer.
\n
* '
...
...
vllm/engine/llm_engine.py
View file @
8fb5dea5
...
@@ -54,6 +54,8 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
...
@@ -54,6 +54,8 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer
)
init_tracer
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
# DEBUG add cpm tokenizer
from
vllm.transformers_utils.tokenizers
import
CPM9GTokenizer
from
vllm.transformers_utils.tokenizer_group
import
(
from
vllm.transformers_utils.tokenizer_group
import
(
TokenizerGroup
,
init_tokenizer_from_configs
)
TokenizerGroup
,
init_tokenizer_from_configs
)
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
...
@@ -250,10 +252,14 @@ class LLMEngine:
...
@@ -250,10 +252,14 @@ class LLMEngine:
self
.
log_stats
=
log_stats
self
.
log_stats
=
log_stats
self
.
use_cached_outputs
=
use_cached_outputs
self
.
use_cached_outputs
=
use_cached_outputs
if
not
self
.
model_config
.
skip_tokenizer_init
:
if
not
self
.
model_config
.
skip_tokenizer_init
and
self
.
model_config
.
tokenizer_mode
!=
"cpm"
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
self
.
tokenizer
=
self
.
_init_tokenizer
()
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
tokenizer_group
=
self
.
get_tokenizer_group
()
tokenizer_group
=
self
.
get_tokenizer_group
()
elif
self
.
model_config
.
tokenizer_mode
==
"cpm"
:
self
.
tokenizer
=
CPM9GTokenizer
(
self
.
model_config
.
model
,
trust_remote_code
=
True
)
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
,
self
.
model_config
.
tokenizer_mode
)
tokenizer_group
=
self
.
get_tokenizer_group
()
else
:
else
:
self
.
tokenizer
=
None
self
.
tokenizer
=
None
self
.
detokenizer
=
None
self
.
detokenizer
=
None
...
@@ -541,7 +547,10 @@ class LLMEngine:
...
@@ -541,7 +547,10 @@ class LLMEngine:
self
,
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
AnyTokenizer
:
)
->
AnyTokenizer
:
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
lora_request
)
if
self
.
model_config
.
tokenizer_mode
==
"cpm"
:
return
self
.
tokenizer
else
:
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
lora_request
)
def
_init_tokenizer
(
self
)
->
TokenizerGroup
:
def
_init_tokenizer
(
self
)
->
TokenizerGroup
:
return
init_tokenizer_from_configs
(
return
init_tokenizer_from_configs
(
...
@@ -592,7 +601,11 @@ class LLMEngine:
...
@@ -592,7 +601,11 @@ class LLMEngine:
# Create the sequences.
# Create the sequences.
block_size
=
self
.
cache_config
.
block_size
block_size
=
self
.
cache_config
.
block_size
seq_id
=
next
(
self
.
seq_counter
)
seq_id
=
next
(
self
.
seq_counter
)
eos_token_id
=
self
.
input_preprocessor
.
get_eos_token_id
(
lora_request
)
#DEBUG @TODO change tokenizer false
if
self
.
model_config
.
tokenizer_mode
==
"cpm"
:
eos_token_id
=
self
.
tokenizer
.
eos_id
else
:
eos_token_id
=
self
.
input_preprocessor
.
get_eos_token_id
(
lora_request
)
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
processed_inputs
)
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
processed_inputs
)
...
@@ -761,6 +774,10 @@ class LLMEngine:
...
@@ -761,6 +774,10 @@ class LLMEngine:
prompt
,
prompt
,
tokenizer
=
self
.
get_tokenizer
(
lora_request
=
lora_request
))
tokenizer
=
self
.
get_tokenizer
(
lora_request
=
lora_request
))
#DEBUG anrongqiao
if
self
.
model_config
.
tokenizer_mode
==
"cpm"
:
lora_request
=
None
processed_inputs
=
self
.
input_preprocessor
.
preprocess
(
processed_inputs
=
self
.
input_preprocessor
.
preprocess
(
prompt
,
prompt
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
...
...
vllm/engine/multiprocessing/client.py
View file @
8fb5dea5
...
@@ -48,6 +48,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
...
@@ -48,6 +48,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.utils
import
Device
,
deprecate_kwargs
from
vllm.utils
import
Device
,
deprecate_kwargs
from
vllm.transformers_utils.tokenizers
import
CPM9GTokenizer
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -98,10 +99,13 @@ class MQLLMEngineClient(EngineClient):
...
@@ -98,10 +99,13 @@ class MQLLMEngineClient(EngineClient):
self
.
decoding_config
=
engine_config
.
decoding_config
self
.
decoding_config
=
engine_config
.
decoding_config
# Create the tokenizer group.
# Create the tokenizer group.
self
.
tokenizer
=
init_tokenizer_from_configs
(
if
self
.
model_config
.
tokenizer_mode
!=
"cpm"
:
model_config
=
self
.
model_config
,
self
.
tokenizer
=
init_tokenizer_from_configs
(
scheduler_config
=
engine_config
.
scheduler_config
,
model_config
=
self
.
model_config
,
lora_config
=
engine_config
.
lora_config
)
scheduler_config
=
engine_config
.
scheduler_config
,
lora_config
=
engine_config
.
lora_config
)
else
:
self
.
tokenizer
=
CPM9GTokenizer
(
self
.
model_config
.
model
,
trust_remote_code
=
True
)
self
.
input_preprocessor
=
InputPreprocessor
(
self
.
model_config
,
self
.
input_preprocessor
=
InputPreprocessor
(
self
.
model_config
,
self
.
tokenizer
)
self
.
tokenizer
)
...
@@ -375,7 +379,7 @@ class MQLLMEngineClient(EngineClient):
...
@@ -375,7 +379,7 @@ class MQLLMEngineClient(EngineClient):
return
self
.
input_preprocessor
return
self
.
input_preprocessor
async
def
get_tokenizer
(
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
):
async
def
get_tokenizer
(
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
):
return
await
self
.
tokenizer
.
get_lora_tokenizer_async
(
lora_request
)
return
await
self
.
tokenizer
.
get_lora_tokenizer_async
(
lora_request
)
if
self
.
model_config
.
tokenizer_mode
!=
"cpm"
else
self
.
tokenizer
async
def
get_vllm_config
(
self
)
->
VllmConfig
:
async
def
get_vllm_config
(
self
)
->
VllmConfig
:
return
self
.
vllm_config
return
self
.
vllm_config
...
...
vllm/entrypoints/llm.py
View file @
8fb5dea5
...
@@ -164,6 +164,8 @@ class LLM:
...
@@ -164,6 +164,8 @@ class LLM:
self
,
self
,
model
:
str
,
model
:
str
,
tokenizer
:
Optional
[
str
]
=
None
,
tokenizer
:
Optional
[
str
]
=
None
,
#need change mode as "cpm" for 9g tokenizer
# tokenizer_mode: str = "cpm",
tokenizer_mode
:
str
=
"auto"
,
tokenizer_mode
:
str
=
"auto"
,
skip_tokenizer_init
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
trust_remote_code
:
bool
=
False
,
trust_remote_code
:
bool
=
False
,
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
8fb5dea5
...
@@ -47,6 +47,7 @@ from vllm.sequence import Logprob, PromptLogprobs
...
@@ -47,6 +47,7 @@ from vllm.sequence import Logprob, PromptLogprobs
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
log_tracing_disabled_warning
)
log_tracing_disabled_warning
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizers
import
CPM9GTokenizer
from
vllm.utils
import
is_list_of
,
make_async
,
random_uuid
from
vllm.utils
import
is_list_of
,
make_async
,
random_uuid
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -86,6 +87,10 @@ class OpenAIServing:
...
@@ -86,6 +87,10 @@ class OpenAIServing:
self
.
engine_client
=
engine_client
self
.
engine_client
=
engine_client
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
max_model_len
=
model_config
.
max_model_len
self
.
max_model_len
=
model_config
.
max_model_len
self
.
tokenizer_mode
=
model_config
.
tokenizer_mode
if
model_config
.
tokenizer_mode
==
"cpm"
:
self
.
tokenizer
=
CPM9GTokenizer
(
model_config
.
model
,
trust_remote_code
=
True
)
self
.
models
=
models
self
.
models
=
models
...
@@ -189,7 +194,10 @@ class OpenAIServing:
...
@@ -189,7 +194,10 @@ class OpenAIServing:
truncation
=
True
,
truncation
=
True
,
max_length
=
truncate_prompt_tokens
)
max_length
=
truncate_prompt_tokens
)
input_ids
=
encoded
.
input_ids
if
self
.
tokenizer_mode
==
"cpm"
:
input_ids
=
[
self
.
tokenizer
.
bos_id
]
+
self
.
tokenizer
.
encode
(
prompt
)
else
:
input_ids
=
encoded
.
input_ids
input_text
=
prompt
input_text
=
prompt
...
@@ -207,7 +215,7 @@ class OpenAIServing:
...
@@ -207,7 +215,7 @@ class OpenAIServing:
else
:
else
:
input_ids
=
prompt_ids
[
-
truncate_prompt_tokens
:]
input_ids
=
prompt_ids
[
-
truncate_prompt_tokens
:]
input_text
=
tokenizer
.
decode
(
input_ids
)
input_text
=
tokenizer
.
decode
(
input_ids
)
if
self
.
tokenizer_mode
!=
"cpm"
else
self
.
tokenizer
.
decode_all
(
input_ids
)
return
self
.
_validate_input
(
request
,
input_ids
,
input_text
)
return
self
.
_validate_input
(
request
,
input_ids
,
input_text
)
...
...
vllm/inputs/preprocess.py
View file @
8fb5dea5
...
@@ -201,9 +201,12 @@ class InputPreprocessor:
...
@@ -201,9 +201,12 @@ class InputPreprocessor:
"do_lower_case"
,
False
)):
"do_lower_case"
,
False
)):
prompt
=
prompt
.
lower
()
prompt
=
prompt
.
lower
()
return
tokenizer
.
encode
(
prompt
=
prompt
,
if
self
.
model_config
.
tokenizer_mode
==
"cpm"
:
lora_request
=
lora_request
,
return
[
tokenizer
.
bos_id
]
+
tokenizer
.
encode
(
prompt
)
add_special_tokens
=
add_special_tokens
)
else
:
return
tokenizer
.
encode
(
prompt
=
prompt
,
lora_request
=
lora_request
,
add_special_tokens
=
add_special_tokens
)
async
def
_tokenize_prompt_async
(
async
def
_tokenize_prompt_async
(
self
,
self
,
...
...
vllm/model_executor/models/fm9g.py
0 → 100644
View file @
8fb5dea5
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/registry.py
View file @
8fb5dea5
...
@@ -53,6 +53,7 @@ _TEXT_GENERATION_MODELS = {
...
@@ -53,6 +53,7 @@ _TEXT_GENERATION_MODELS = {
"DeepseekV3ForCausalLM"
:
(
"deepseek_v2"
,
"DeepseekV3ForCausalLM"
),
"DeepseekV3ForCausalLM"
:
(
"deepseek_v2"
,
"DeepseekV3ForCausalLM"
),
"ExaoneForCausalLM"
:
(
"exaone"
,
"ExaoneForCausalLM"
),
"ExaoneForCausalLM"
:
(
"exaone"
,
"ExaoneForCausalLM"
),
"FalconForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"FalconForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"FM9GForCausalLM"
:
(
"fm9g"
,
"FM9GForCausalLM"
),
"Fairseq2LlamaForCausalLM"
:
(
"fairseq2_llama"
,
"Fairseq2LlamaForCausalLM"
),
"Fairseq2LlamaForCausalLM"
:
(
"fairseq2_llama"
,
"Fairseq2LlamaForCausalLM"
),
"GemmaForCausalLM"
:
(
"gemma"
,
"GemmaForCausalLM"
),
"GemmaForCausalLM"
:
(
"gemma"
,
"GemmaForCausalLM"
),
"Gemma2ForCausalLM"
:
(
"gemma2"
,
"Gemma2ForCausalLM"
),
"Gemma2ForCausalLM"
:
(
"gemma2"
,
"Gemma2ForCausalLM"
),
...
...
vllm/transformers_utils/configs/__init__.py
View file @
8fb5dea5
...
@@ -10,6 +10,7 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig
...
@@ -10,6 +10,7 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
# `FalconConfig` class from the official HuggingFace transformers library.
from
vllm.transformers_utils.configs.falcon
import
RWConfig
from
vllm.transformers_utils.configs.falcon
import
RWConfig
from
vllm.transformers_utils.configs.fm9g
import
FM9GConfig
from
vllm.transformers_utils.configs.h2ovl
import
H2OVLChatConfig
from
vllm.transformers_utils.configs.h2ovl
import
H2OVLChatConfig
from
vllm.transformers_utils.configs.internvl
import
InternVLChatConfig
from
vllm.transformers_utils.configs.internvl
import
InternVLChatConfig
from
vllm.transformers_utils.configs.jais
import
JAISConfig
from
vllm.transformers_utils.configs.jais
import
JAISConfig
...
@@ -31,6 +32,7 @@ __all__ = [
...
@@ -31,6 +32,7 @@ __all__ = [
"Cohere2Config"
,
"Cohere2Config"
,
"DbrxConfig"
,
"DbrxConfig"
,
"DeepseekVLV2Config"
,
"DeepseekVLV2Config"
,
"FM9GConfig"
,
"MPTConfig"
,
"MPTConfig"
,
"RWConfig"
,
"RWConfig"
,
"H2OVLChatConfig"
,
"H2OVLChatConfig"
,
...
...
vllm/transformers_utils/configs/fm9g.py
0 → 100644
View file @
8fb5dea5
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FM9G model configuration"""
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.utils
import
logging
logger
=
logging
.
get_logger
(
__name__
)
FM9G_PRETRAINED_CONFIG_ARCHIVE_MAP
=
{}
class FM9GConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`FM9GModel`]. It is used to instantiate an FM9G
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the FM9G-7B.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the FM9G model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`FM9GModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        scale_emb (`int` or `float`, *optional*, defaults to 1):
            Stored on the config unchanged for the model implementation to read. Presumably an embedding scaling
            factor -- TODO confirm against the FM9G model code.
        dim_model_base (`int`, *optional*, defaults to 1):
            Stored on the config unchanged for the model implementation to read. Presumably a reference width used
            for output scaling -- TODO confirm against the FM9G model code.
        scale_depth (`int` or `float`, *optional*, defaults to 1):
            Stored on the config unchanged for the model implementation to read. Presumably a depth-dependent
            residual scaling factor -- TODO confirm against the FM9G model code.

    Raises:
        ValueError: If `rope_scaling` is given but is not a two-entry dict with a `type` of `"linear"` or
            `"dynamic"` and a float `factor` greater than 1.
    """

    model_type = "fm9g"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        scale_emb=1,
        dim_model_base=1,
        scale_depth=1,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility: configs without an explicit KV head
        # count fall back to one KV head per attention head.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        # Validate as soon as `rope_scaling` is stored so that a malformed
        # value fails before the rest of the config is populated.
        self._rope_scaling_validation()
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.scale_emb = scale_emb
        self.dim_model_base = dim_model_base
        self.scale_depth = scale_depth

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

        # Best-effort opt-in to FlashAttention-2: any failure while probing
        # for the optional `flash_attn` package leaves the attention
        # implementation untouched. `Exception` (rather than a bare
        # `except:`) keeps that leniency without also swallowing
        # KeyboardInterrupt / SystemExit.
        try:
            import flash_attn  # noqa: F401

            self._attn_implementation = "flash_attention_2"
        except Exception:  # noqa: BLE001
            pass

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.

        A value of `None` is accepted as "no scaling". Anything else must be
        a dict with exactly two entries: a `type` of `"linear"` or
        `"dynamic"`, and a `factor` that is a `float` strictly greater
        than 1.

        Raises:
            ValueError: If `rope_scaling` does not have that shape.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
                f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        # A missing key yields None, which fails both checks below, so no
        # separate `is None` test is needed.
        if rope_scaling_type not in ("linear", "dynamic"):
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        if not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
            raise ValueError(
                f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}"
            )
\ No newline at end of file
vllm/transformers_utils/detokenizer.py
View file @
8fb5dea5
...
@@ -14,8 +14,12 @@ from .tokenizer_group import TokenizerGroup
...
@@ -14,8 +14,12 @@ from .tokenizer_group import TokenizerGroup
class
Detokenizer
:
class
Detokenizer
:
"""Provides methods to decode the output of a model into text."""
"""Provides methods to decode the output of a model into text."""
def
__init__
(
self
,
tokenizer_group
:
TokenizerGroup
):
def
__init__
(
self
,
tokenizer_group
:
TokenizerGroup
,
mode
=
"auto"
):
self
.
tokenizer_group
=
tokenizer_group
self
.
mode
=
mode
if
self
.
mode
!=
"cpm"
:
self
.
tokenizer_group
=
tokenizer_group
else
:
self
.
tokenizer
=
tokenizer_group
def
get_tokenizer_for_seq
(
self
,
sequence
:
Sequence
)
->
AnyTokenizer
:
def
get_tokenizer_for_seq
(
self
,
sequence
:
Sequence
)
->
AnyTokenizer
:
"""Returns the HF tokenizer to use for a given sequence."""
"""Returns the HF tokenizer to use for a given sequence."""
...
@@ -44,7 +48,10 @@ class Detokenizer:
...
@@ -44,7 +48,10 @@ class Detokenizer:
# Only prompt, without the generated token.
# Only prompt, without the generated token.
all_token_ids
=
seq
.
get_token_ids
()
all_token_ids
=
seq
.
get_token_ids
()
prompt_token_ids
=
all_token_ids
[:
-
1
]
prompt_token_ids
=
all_token_ids
[:
-
1
]
tokenizer
=
self
.
get_tokenizer_for_seq
(
seq
)
if
self
.
mode
!=
"cpm"
:
tokenizer
=
self
.
get_tokenizer_for_seq
(
seq
)
else
:
tokenizer
=
self
.
tokenizer
prefix_offset
=
0
prefix_offset
=
0
read_offset
=
0
read_offset
=
0
next_iter_prefix_offset
=
0
next_iter_prefix_offset
=
0
...
@@ -76,6 +83,7 @@ class Detokenizer:
...
@@ -76,6 +83,7 @@ class Detokenizer:
skip_special_tokens
=
prms
.
skip_special_tokens
,
skip_special_tokens
=
prms
.
skip_special_tokens
,
spaces_between_special_tokens
=
prms
.
spaces_between_special_tokens
=
prms
.
spaces_between_special_tokens
,
spaces_between_special_tokens
,
mode
=
self
.
mode
,
)
)
sample_logprob
.
decoded_token
=
new_text
sample_logprob
.
decoded_token
=
new_text
...
@@ -109,7 +117,10 @@ class Detokenizer:
...
@@ -109,7 +117,10 @@ class Detokenizer:
"""
"""
all_input_ids
=
seq
.
get_token_ids
()
all_input_ids
=
seq
.
get_token_ids
()
token_id_generated_this_iteration
=
all_input_ids
[
-
1
]
token_id_generated_this_iteration
=
all_input_ids
[
-
1
]
tokenizer
=
self
.
get_tokenizer_for_seq
(
seq
)
if
self
.
mode
!=
"cpm"
:
tokenizer
=
self
.
get_tokenizer_for_seq
(
seq
)
else
:
tokenizer
=
self
.
tokenizer
# Convert prompt token IDs to tokens if necessary.
# Convert prompt token IDs to tokens if necessary.
# Do it here so that we don't have to repeat this
# Do it here so that we don't have to repeat this
...
@@ -131,6 +142,7 @@ class Detokenizer:
...
@@ -131,6 +142,7 @@ class Detokenizer:
read_offset
=
seq
.
read_offset
,
read_offset
=
seq
.
read_offset
,
skip_special_tokens
=
prms
.
skip_special_tokens
,
skip_special_tokens
=
prms
.
skip_special_tokens
,
spaces_between_special_tokens
=
prms
.
spaces_between_special_tokens
,
spaces_between_special_tokens
=
prms
.
spaces_between_special_tokens
,
mode
=
self
.
mode
,
)
)
# Decode logprobs
# Decode logprobs
...
@@ -156,6 +168,7 @@ class Detokenizer:
...
@@ -156,6 +168,7 @@ class Detokenizer:
skip_special_tokens
=
prms
.
skip_special_tokens
,
skip_special_tokens
=
prms
.
skip_special_tokens
,
spaces_between_special_tokens
=
prms
.
spaces_between_special_tokens
=
prms
.
spaces_between_special_tokens
,
spaces_between_special_tokens
,
mode
=
self
.
mode
,
)
)
sample_logprob
.
decoded_token
=
new_text
sample_logprob
.
decoded_token
=
new_text
...
...
vllm/transformers_utils/detokenizer_utils.py
View file @
8fb5dea5
...
@@ -16,6 +16,7 @@ def _convert_tokens_to_string_with_added_encoders(
...
@@ -16,6 +16,7 @@ def _convert_tokens_to_string_with_added_encoders(
output_tokens
:
List
[
str
],
output_tokens
:
List
[
str
],
skip_special_tokens
:
bool
,
skip_special_tokens
:
bool
,
spaces_between_special_tokens
:
bool
,
spaces_between_special_tokens
:
bool
,
mode
:
str
,
)
->
str
:
)
->
str
:
# Adapted from
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
...
@@ -24,7 +25,10 @@ def _convert_tokens_to_string_with_added_encoders(
...
@@ -24,7 +25,10 @@ def _convert_tokens_to_string_with_added_encoders(
# even when the loop body is very simple.
# even when the loop body is very simple.
sub_texts
:
List
[
str
]
=
[]
sub_texts
:
List
[
str
]
=
[]
current_sub_text
:
List
[
str
]
=
[]
current_sub_text
:
List
[
str
]
=
[]
all_special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
if
mode
!=
"cpm"
:
all_special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
else
:
all_special_tokens
=
tokenizer
.
_special_token_set
for
token
in
output_tokens
:
for
token
in
output_tokens
:
if
skip_special_tokens
and
token
in
all_special_tokens
:
if
skip_special_tokens
and
token
in
all_special_tokens
:
continue
continue
...
@@ -37,7 +41,10 @@ def _convert_tokens_to_string_with_added_encoders(
...
@@ -37,7 +41,10 @@ def _convert_tokens_to_string_with_added_encoders(
else
:
else
:
current_sub_text
.
append
(
token
)
current_sub_text
.
append
(
token
)
if
current_sub_text
:
if
current_sub_text
:
sub_text
=
tokenizer
.
convert_tokens_to_string
(
current_sub_text
)
if
mode
!=
"cpm"
:
sub_text
=
tokenizer
.
convert_tokens_to_string
(
current_sub_text
)
else
:
sub_text
=
tokenizer
.
decode
(
current_sub_text
)
sub_texts
.
append
(
sub_text
)
sub_texts
.
append
(
sub_text
)
if
spaces_between_special_tokens
:
if
spaces_between_special_tokens
:
return
" "
.
join
(
sub_texts
)
return
" "
.
join
(
sub_texts
)
...
@@ -104,6 +111,7 @@ def detokenize_incrementally(
...
@@ -104,6 +111,7 @@ def detokenize_incrementally(
read_offset
:
int
,
read_offset
:
int
,
skip_special_tokens
:
bool
=
False
,
skip_special_tokens
:
bool
=
False
,
spaces_between_special_tokens
:
bool
=
True
,
spaces_between_special_tokens
:
bool
=
True
,
mode
:
str
=
"cpm"
,
)
->
Tuple
[
List
[
str
],
str
,
int
,
int
]:
)
->
Tuple
[
List
[
str
],
str
,
int
,
int
]:
"""Detokenizes the input ids incrementally and returns the new tokens
"""Detokenizes the input ids incrementally and returns the new tokens
and the new text.
and the new text.
...
@@ -141,7 +149,11 @@ def detokenize_incrementally(
...
@@ -141,7 +149,11 @@ def detokenize_incrementally(
assert
prev_tokens
is
not
None
assert
prev_tokens
is
not
None
# If the new token id is out of bounds, return an empty string.
# If the new token id is out of bounds, return an empty string.
if
0
<=
new_token_id
<
len
(
tokenizer
):
if
mode
==
"cpm"
:
vocab_size
=
tokenizer
.
vocab_size
else
:
vocab_size
=
len
(
tokenizer
)
if
0
<=
new_token_id
<
vocab_size
:
# Put new_token_id in a list so skip_special_tokens is respected
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens
=
tokenizer
.
convert_ids_to_tokens
(
new_tokens
=
tokenizer
.
convert_ids_to_tokens
(
[
new_token_id
],
skip_special_tokens
=
skip_special_tokens
)
[
new_token_id
],
skip_special_tokens
=
skip_special_tokens
)
...
@@ -169,12 +181,14 @@ def detokenize_incrementally(
...
@@ -169,12 +181,14 @@ def detokenize_incrementally(
output_tokens
[
prefix_offset
:
read_offset
],
output_tokens
[
prefix_offset
:
read_offset
],
skip_special_tokens
=
skip_special_tokens
,
skip_special_tokens
=
skip_special_tokens
,
spaces_between_special_tokens
=
spaces_between_special_tokens
,
spaces_between_special_tokens
=
spaces_between_special_tokens
,
mode
=
mode
,
)
)
new_text
=
_convert_tokens_to_string_with_added_encoders
(
new_text
=
_convert_tokens_to_string_with_added_encoders
(
tokenizer
,
tokenizer
,
output_tokens
[
prefix_offset
:],
output_tokens
[
prefix_offset
:],
skip_special_tokens
=
skip_special_tokens
,
skip_special_tokens
=
skip_special_tokens
,
spaces_between_special_tokens
=
spaces_between_special_tokens
,
spaces_between_special_tokens
=
spaces_between_special_tokens
,
mode
=
mode
,
)
)
if
len
(
new_text
)
<=
len
(
prefix_text
)
or
new_text
.
endswith
(
"�"
):
if
len
(
new_text
)
<=
len
(
prefix_text
)
or
new_text
.
endswith
(
"�"
):
...
...
vllm/transformers_utils/tokenizers/__init__.py
View file @
8fb5dea5
...
@@ -2,8 +2,10 @@
...
@@ -2,8 +2,10 @@
from
.mistral
import
(
MistralTokenizer
,
maybe_serialize_tool_calls
,
from
.mistral
import
(
MistralTokenizer
,
maybe_serialize_tool_calls
,
truncate_tool_call_ids
,
validate_request_params
)
truncate_tool_call_ids
,
validate_request_params
)
from
vllm.transformers_utils.tokenizers.cpm_9g
import
CPM9GTokenizer
__all__
=
[
__all__
=
[
"MistralTokenizer"
,
"maybe_serialize_tool_calls"
,
"truncate_tool_call_ids"
,
"MistralTokenizer"
,
"maybe_serialize_tool_calls"
,
"truncate_tool_call_ids"
,
"validate_request_params"
"validate_request_params"
,
"CPM9GTokenizer"
]
]
vllm/transformers_utils/tokenizers/cpm_9g.py
0 → 100644
View file @
8fb5dea5
import
io
import
json
import
os
from
shutil
import
copyfile
from
typing
import
Any
,
Dict
,
IO
,
List
,
Optional
,
Tuple
import
pkg_resources
import
sentencepiece
as
spm
from
pytrie
import
StringTrie
from
transformers.tokenization_utils
import
AddedToken
,
PreTrainedTokenizer
from
transformers.utils
import
logging
# Module-level logger obtained through the transformers logging helper.
logger = logging.get_logger(__name__)

# File name used for the vocabulary inside a model directory.
VOCAB_FILES_NAMES: Dict[str, str] = dict(vocab_file="vocab.txt")

# Both lookup tables are intentionally empty in this module.
PRETRAINED_VOCAB_FILES_MAP: Dict[str, Dict[str, Any]] = dict(
    vocab_file={},
    tokenizer_file={},
)

# Empty as well: no per-checkpoint input-size limits are declared here.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES: Dict[str, int] = dict()
class CPM9GTokenizer(PreTrainedTokenizer):
    """Tokenizer for CPM-9G models backed by a plain-text vocabulary file.

    Text is segmented by greedy longest-prefix matching against the
    vocabulary. Pieces that are not vocabulary entries are encoded as their
    UTF-8 bytes through the ``<0x00>`` .. ``<0xFF>`` byte tokens.

    Args:
        vocab_file: Path to the vocabulary file, or to a directory that
            contains ``vocab.txt``. Each non-empty line of the file is a
            JSON-encoded string; the token id is the order of appearance.
        unk_token: Token used for unknown ids/tokens.
        bos_token: Beginning-of-sequence token.
        eos_token: End-of-sequence token.
        pad_token: Optional padding token.
        sp_model_kwargs: Stored and forwarded to the base class; not
            otherwise used by this class.
        add_bos_token: Whether ``build_inputs_with_special_tokens``
            prepends the BOS id.
        add_eos_token: Whether ``build_inputs_with_special_tokens``
            appends the EOS id.
        clean_up_tokenization_spaces: Forwarded to the base class.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        unk_token: str = "<unk>",
        bos_token: str = "<s>",
        eos_token: str = "</s>",
        pad_token: Optional[str] = None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token: bool = True,
        add_eos_token: bool = False,
        clean_up_tokenization_spaces: bool = False,
        **kwargs,
    ):
        self.sp_model_kwargs = sp_model_kwargs or {}
        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.unk_token = unk_token
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.pad_token = pad_token

        # "<0x00>" .. "<0xFF>": one token per possible byte value, indexed
        # by that byte value.
        self.byte_list: List[str] = [f"<0x{i:02X}>" for i in range(0x100)]
        self._special_token_set = set(
            [self.unk_token, self.bos_token, self.eos_token] + self.byte_list
        )

        if not vocab_file:
            # Previously this fell through to an unbound-variable error.
            raise ValueError(
                "CPM9GTokenizer requires `vocab_file` (a vocabulary file or "
                "a directory containing vocab.txt)"
            )
        # Open the path that was actually given (previously a path
        # containing "vocab.txt" was ignored in favour of ./vocab.txt) and
        # make sure the handle is closed.
        with open(self._resolve_vocab_path(vocab_file), "rb") as vocab_fp:
            all_tokens = self.load_vocab(vocab_fp)

        # Ordinary tokens and special/byte tokens live in separate tables.
        self.encoder: Dict[str, int] = {}
        self._special_encoder: Dict[str, int] = {}
        for token, token_id in all_tokens.items():
            if token in self._special_token_set:
                self._special_encoder[token] = token_id
            else:
                self.encoder[token] = token_id

        self.decoder = {v: k for k, v in self.encoder.items()}
        # token id of a byte token -> the byte value it stands for
        self._byte_decoder = {
            self._special_encoder[token]: i
            for i, token in enumerate(self.byte_list)
        }

        self._max_word_len = max((len(x) for x in self.encoder), default=0)
        # first character -> length of the longest entry starting with it;
        # bounds the prefix search done in get_piece().
        self._len_word_first: Dict[str, int] = {}
        for word in self.encoder:
            if not word:
                continue  # an empty entry can never be matched as a piece
            first = word[0]
            if len(word) > self._len_word_first.get(first, 0):
                self._len_word_first[first] = len(word)

        self.tencoder = StringTrie(self.encoder)
        self._max_token_id = self.vocab_size - 1

        super().__init__(
            bos_token=AddedToken(bos_token, lstrip=False, rstrip=False),
            eos_token=AddedToken(eos_token, lstrip=False, rstrip=False),
            unk_token=AddedToken(unk_token, lstrip=False, rstrip=False),
            pad_token=AddedToken(pad_token, lstrip=False, rstrip=False)
            if pad_token
            else None,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            sp_model_kwargs=self.sp_model_kwargs,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @staticmethod
    def _resolve_vocab_path(vocab_file: str) -> str:
        """Return the vocabulary file path for a file-or-directory argument."""
        if os.path.isdir(vocab_file):
            return os.path.join(vocab_file, VOCAB_FILES_NAMES["vocab_file"])
        return vocab_file

    def __getstate__(self) -> Dict[str, Any]:
        """Return a copy of the instance dict with ``sp_model`` set to None."""
        state = self.__dict__.copy()
        # This class never creates an `sp_model`; the key is kept so the
        # pickled state has the same shape as before.
        state["sp_model"] = None
        return state

    def __setstate__(self, d: Dict[str, Any]) -> None:
        """Restore the instance dict produced by ``__getstate__``."""
        self.__dict__ = d

    def load_vocab(self, fp: IO[bytes]) -> Dict[str, int]:
        """Read a vocabulary file into a token -> id mapping.

        Args:
            fp: Binary file object of the vocabulary; decoded as UTF-8.

        Returns:
            Mapping from token to id. Blank lines are skipped; every other
            line is parsed with ``json.loads`` and receives the current
            size of the mapping as its id.
        """
        vocab: Dict[str, int] = {}
        reader = io.TextIOWrapper(fp, encoding="utf-8")
        for line in reader:
            line = line.strip()
            if not line:
                continue
            vocab[json.loads(line)] = len(vocab)
        return vocab

    @property
    def vocab_size(self) -> int:
        """Number of ordinary tokens plus special/byte tokens."""
        return len(self.encoder) + len(self._special_encoder)

    @property
    def max_token_id(self) -> int:
        """``vocab_size - 1`` as computed at construction time."""
        return self._max_token_id

    @property
    def eos_id(self):
        """Vocabulary id of the EOS token."""
        return self._special_encoder[self.eos_token]

    @property
    def bos_id(self):
        """Vocabulary id of the BOS token."""
        return self._special_encoder[self.bos_token]

    @property
    def unk_id(self):
        """Vocabulary id of the unknown token."""
        return self._special_encoder[self.unk_token]

    def get_vocab(self) -> Dict[str, int]:
        """Return the vocabulary as a token -> id dictionary."""
        # NOTE(review): ids are mapped through convert_ids_to_tokens(), and
        # _convert_id_to_token() only knows ordinary tokens, so special and
        # byte ids may all collapse onto the unknown token here. Left as is
        # because the base class consults this during construction --
        # confirm before changing.
        vocab = {
            self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)
        }
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into pieces by repeated longest-prefix matching."""
        output_tokens: List[str] = []
        # get_piece() never looks further than the longest vocabulary
        # entry, so hand it a bounded window instead of copying the whole
        # remaining text on every step.
        window = max(self._max_word_len, 1)
        pos = 0
        total = len(text)
        while pos < total:
            piece = self.get_piece(text[pos:pos + window])
            output_tokens.append(piece)
            pos += len(piece)
        return output_tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Map an ordinary token to its id; anything else maps to unk_id."""
        return self.encoder.get(token, self.unk_id)

    def _convert_id_to_token(self, index: int) -> str:
        """Map an ordinary id to its token; anything else maps to unk_token."""
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join a sequence of token strings into one string.

        Ordinary tokens are concatenated. Tokens in the special set
        (unk/bos/eos and the byte tokens) are emitted verbatim; a single
        space is inserted before a special token that follows ordinary
        text.
        """
        current_sub_tokens: List[str] = []
        out_string = ""
        prev_is_special = False
        for i, token in enumerate(tokens):
            if token in self._special_token_set:
                if not prev_is_special and i != 0:
                    out_string += " "
                out_string += "".join(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        # The trailing tokens used to go through `self.sp_model`, which
        # this class never creates; they are plain strings, so join them.
        out_string += "".join(current_sub_tokens)
        return out_string

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """Save the vocabulary file into ``save_directory``.

        The source vocabulary file is copied when it exists; otherwise the
        in-memory vocabulary is written in the format ``load_vocab`` reads
        (one JSON string per line, ordered by id).

        Args:
            save_directory: Existing directory to write into.
            filename_prefix: Optional prefix; written as
                ``<prefix>-vocab.txt``.

        Returns:
            One-element tuple holding the path of the saved file.

        Raises:
            ValueError: If ``save_directory`` is not a directory.
        """
        if not os.path.isdir(save_directory):
            raise ValueError(
                f"Vocabulary path ({save_directory}) should be a directory"
            )
        prefix = filename_prefix + "-" if filename_prefix else ""
        out_vocab_file = os.path.join(
            save_directory, prefix + VOCAB_FILES_NAMES["vocab_file"]
        )

        # `vocab_file` may be a directory (see __init__), so resolve it to
        # the real file before deciding whether it can be copied.
        source = (
            self._resolve_vocab_path(self.vocab_file)
            if self.vocab_file
            else None
        )
        if source is not None and os.path.isfile(source):
            if os.path.abspath(source) != os.path.abspath(out_vocab_file):
                copyfile(source, out_vocab_file)
        else:
            id_to_token = dict(self.decoder)
            id_to_token.update(
                {tid: tok for tok, tid in self._special_encoder.items()}
            )
            with open(out_vocab_file, "w", encoding="utf-8") as fo:
                fo.writelines(
                    json.dumps(id_to_token[tid], ensure_ascii=False) + "\n"
                    for tid in sorted(id_to_token)
                )
        return (out_vocab_file,)

    def build_inputs_with_special_tokens(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:
        """Wrap one or two id sequences with BOS/EOS ids as configured."""
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = bos_token_id + token_ids_0 + eos_token_id
        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id
        return output

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return a 0/1 mask marking where special tokens are (or would be).

        Args:
            token_ids_0: First id sequence.
            token_ids_1: Optional second id sequence of a pair.
            already_has_special_tokens: If True, defer to the base class.

        Returns:
            List with 1 at special-token positions and 0 elsewhere, laid
            out like ``build_inputs_with_special_tokens``.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        bos_token_id = [1] if self.add_bos_token else []
        eos_token_id = [1] if self.add_eos_token else []

        mask = bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
        if token_ids_1 is None:
            return mask
        return mask + bos_token_id + ([0] * len(token_ids_1)) + eos_token_id

    def create_token_type_ids_from_sequences(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:
        """Return token type ids: 0 for the first sequence, 1 for the second.

        The lengths include the BOS/EOS positions that
        ``build_inputs_with_special_tokens`` would add.
        """
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
        if token_ids_1 is not None:
            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
        return output

    def get_piece(self, text: str) -> str:
        """Return the longest vocabulary entry that prefixes ``text``.

        Falls back to the first character when no entry matches, and to
        the empty string for empty input.
        """
        if not text:
            return ""
        first = text[0]
        longest = self._len_word_first.get(first)
        if longest is None:
            # No vocabulary entry starts with this character, so probing
            # longer prefixes cannot succeed.
            return first
        for end in range(min(longest, len(text)), 0, -1):
            candidate = text[:end]
            if candidate in self.encoder:
                return candidate
        return first

    def encode(self, text: str) -> List[int]:
        """Encode ``text`` into a list of token ids.

        Pieces found in the vocabulary map to their id; any other piece is
        expanded into the ids of its UTF-8 byte tokens.
        """
        ret: List[int] = []
        for piece in self._tokenize(text):
            token_id = self.encoder.get(piece)
            if token_id is None:
                ret.extend(self._encode_unicode(piece))
            else:
                ret.append(token_id)
        return ret

    def _decode_items(self, tokens: List[Any], keep_strings: bool) -> str:
        """Shared implementation of ``decode`` and ``decode_all``.

        Consecutive byte-token ids are gathered and decoded as one UTF-8
        byte run, so adjacent multi-byte characters decode correctly and
        an incomplete or invalid run yields U+FFFD instead of raising.
        """
        pieces: List[str] = []
        pending = bytearray()
        for item in tokens:
            byte_value = (
                None if isinstance(item, str) else self._byte_decoder.get(item)
            )
            if byte_value is not None:
                pending.append(byte_value)
                continue
            if pending:
                pieces.append(pending.decode("utf-8", errors="replace"))
                pending.clear()
            if keep_strings and isinstance(item, str):
                pieces.append(item)
            elif item in self.decoder:
                pieces.append(self.decoder[item])
            elif item == self.eos_id:
                pieces.append(self.eos_token)
            elif item == self.bos_id:
                pieces.append(self.bos_token)
            else:
                pieces.append(self.unk_token)
        if pending:
            pieces.append(pending.decode("utf-8", errors="replace"))
        return "".join(pieces)

    def decode_all(self, tokens: List[int]):
        """Decode ids into a string; unknown ids become the unknown token."""
        return self._decode_items(tokens, keep_strings=False)

    def decode(self, tokens: List[int]) -> str:
        """Decode a list of ids into a string.

        Items that are already strings are passed through unchanged.
        Ordinary integer ids, which previously made the final join raise
        ``TypeError``, are looked up in the vocabulary.
        """
        return self._decode_items(tokens, keep_strings=True)

    def _encode_unicode(self, token: str) -> List[int]:
        """Return the byte-token ids for the UTF-8 encoding of ``token``."""
        return [
            self._special_encoder[self.byte_list[byte]]
            for byte in token.encode("utf-8")
        ]

    def next_token(self, text: str) -> Tuple[str, List[int]]:
        """Return the next matched token of ``text`` and its id list.

        Uses the trie for the longest-prefix lookup; when nothing matches,
        the first character is returned with its byte-token ids.
        """
        token, token_id = self.tencoder.longest_prefix_item(text, (None, None))
        if token is None:
            token = text[0]
            token_ids = self._encode_unicode(token)
        else:
            token_ids = [token_id]
        return token, token_ids
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment