Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8fb5dea5
Commit
8fb5dea5
authored
May 20, 2025
by
zhuwenwen
Browse files
support qiyuan-8b-v2 and FM9GForCausalLM
parent
a5aa55e8
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
1352 additions
and
24 deletions
+1352
-24
vllm/config.py
vllm/config.py
+2
-2
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-1
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+20
-3
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+9
-5
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+2
-0
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+10
-2
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+6
-3
vllm/model_executor/models/fm9g.py
vllm/model_executor/models/fm9g.py
+592
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+2
-0
vllm/transformers_utils/configs/fm9g.py
vllm/transformers_utils/configs/fm9g.py
+187
-0
vllm/transformers_utils/detokenizer.py
vllm/transformers_utils/detokenizer.py
+17
-4
vllm/transformers_utils/detokenizer_utils.py
vllm/transformers_utils/detokenizer_utils.py
+17
-3
vllm/transformers_utils/tokenizers/__init__.py
vllm/transformers_utils/tokenizers/__init__.py
+3
-1
vllm/transformers_utils/tokenizers/cpm_9g.py
vllm/transformers_utils/tokenizers/cpm_9g.py
+483
-0
No files found.
vllm/config.py
View file @
8fb5dea5
...
...
@@ -640,10 +640,10 @@ class ModelConfig:
def
_verify_tokenizer_mode
(
self
)
->
None
:
tokenizer_mode
=
self
.
tokenizer_mode
.
lower
()
if
tokenizer_mode
not
in
[
"auto"
,
"slow"
,
"mistral"
,
"custom"
]:
if
tokenizer_mode
not
in
[
"auto"
,
"cpm"
,
"slow"
,
"mistral"
,
"custom"
]:
raise
ValueError
(
f
"Unknown tokenizer mode:
{
self
.
tokenizer_mode
}
. Must be "
"either 'auto', 'slow', 'mistral' or 'custom'."
)
"either 'auto',
'cpm',
'slow', 'mistral' or 'custom'."
)
self
.
tokenizer_mode
=
tokenizer_mode
def
_get_preferred_task
(
...
...
vllm/engine/arg_utils.py
View file @
8fb5dea5
...
...
@@ -421,7 +421,7 @@ class EngineArgs:
'--tokenizer-mode'
,
type
=
str
,
default
=
EngineArgs
.
tokenizer_mode
,
choices
=
[
'auto'
,
'slow'
,
'mistral'
,
'custom'
],
choices
=
[
'auto'
,
'cpm'
,
'slow'
,
'mistral'
,
'custom'
],
help
=
'The tokenizer mode.
\n\n
* "auto" will use the '
'fast tokenizer if available.
\n
* "slow" will '
'always use the slow tokenizer.
\n
* '
...
...
vllm/engine/llm_engine.py
View file @
8fb5dea5
...
...
@@ -54,6 +54,8 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
# DEBUG add cpm tokenizer
from
vllm.transformers_utils.tokenizers
import
CPM9GTokenizer
from
vllm.transformers_utils.tokenizer_group
import
(
TokenizerGroup
,
init_tokenizer_from_configs
)
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
...
...
@@ -250,10 +252,14 @@ class LLMEngine:
self
.
log_stats
=
log_stats
self
.
use_cached_outputs
=
use_cached_outputs
if
not
self
.
model_config
.
skip_tokenizer_init
:
if
not
self
.
model_config
.
skip_tokenizer_init
and
self
.
model_config
.
tokenizer_mode
!=
"cpm"
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
tokenizer_group
=
self
.
get_tokenizer_group
()
elif
self
.
model_config
.
tokenizer_mode
==
"cpm"
:
self
.
tokenizer
=
CPM9GTokenizer
(
self
.
model_config
.
model
,
trust_remote_code
=
True
)
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
,
self
.
model_config
.
tokenizer_mode
)
tokenizer_group
=
self
.
get_tokenizer_group
()
else
:
self
.
tokenizer
=
None
self
.
detokenizer
=
None
...
...
@@ -541,6 +547,9 @@ class LLMEngine:
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
AnyTokenizer
:
if
self
.
model_config
.
tokenizer_mode
==
"cpm"
:
return
self
.
tokenizer
else
:
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
lora_request
)
def
_init_tokenizer
(
self
)
->
TokenizerGroup
:
...
...
@@ -592,6 +601,10 @@ class LLMEngine:
# Create the sequences.
block_size
=
self
.
cache_config
.
block_size
seq_id
=
next
(
self
.
seq_counter
)
#DEBUG @TODO change tokenizer false
if
self
.
model_config
.
tokenizer_mode
==
"cpm"
:
eos_token_id
=
self
.
tokenizer
.
eos_id
else
:
eos_token_id
=
self
.
input_preprocessor
.
get_eos_token_id
(
lora_request
)
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
processed_inputs
)
...
...
@@ -761,6 +774,10 @@ class LLMEngine:
prompt
,
tokenizer
=
self
.
get_tokenizer
(
lora_request
=
lora_request
))
#DEBUG anrongqiao
if
self
.
model_config
.
tokenizer_mode
==
"cpm"
:
lora_request
=
None
processed_inputs
=
self
.
input_preprocessor
.
preprocess
(
prompt
,
lora_request
=
lora_request
,
...
...
vllm/engine/multiprocessing/client.py
View file @
8fb5dea5
...
...
@@ -48,6 +48,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.utils
import
Device
,
deprecate_kwargs
from
vllm.transformers_utils.tokenizers
import
CPM9GTokenizer
logger
=
init_logger
(
__name__
)
...
...
@@ -98,10 +99,13 @@ class MQLLMEngineClient(EngineClient):
self
.
decoding_config
=
engine_config
.
decoding_config
# Create the tokenizer group.
if
self
.
model_config
.
tokenizer_mode
!=
"cpm"
:
self
.
tokenizer
=
init_tokenizer_from_configs
(
model_config
=
self
.
model_config
,
scheduler_config
=
engine_config
.
scheduler_config
,
lora_config
=
engine_config
.
lora_config
)
else
:
self
.
tokenizer
=
CPM9GTokenizer
(
self
.
model_config
.
model
,
trust_remote_code
=
True
)
self
.
input_preprocessor
=
InputPreprocessor
(
self
.
model_config
,
self
.
tokenizer
)
...
...
@@ -375,7 +379,7 @@ class MQLLMEngineClient(EngineClient):
return
self
.
input_preprocessor
async
def
get_tokenizer
(
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
):
return
await
self
.
tokenizer
.
get_lora_tokenizer_async
(
lora_request
)
return
await
self
.
tokenizer
.
get_lora_tokenizer_async
(
lora_request
)
if
self
.
model_config
.
tokenizer_mode
!=
"cpm"
else
self
.
tokenizer
async
def
get_vllm_config
(
self
)
->
VllmConfig
:
return
self
.
vllm_config
...
...
vllm/entrypoints/llm.py
View file @
8fb5dea5
...
...
@@ -164,6 +164,8 @@ class LLM:
self
,
model
:
str
,
tokenizer
:
Optional
[
str
]
=
None
,
#need change mode as "cpm" for 9g tokenizer
# tokenizer_mode: str = "cpm",
tokenizer_mode
:
str
=
"auto"
,
skip_tokenizer_init
:
bool
=
False
,
trust_remote_code
:
bool
=
False
,
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
8fb5dea5
...
...
@@ -47,6 +47,7 @@ from vllm.sequence import Logprob, PromptLogprobs
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
log_tracing_disabled_warning
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizers
import
CPM9GTokenizer
from
vllm.utils
import
is_list_of
,
make_async
,
random_uuid
logger
=
init_logger
(
__name__
)
...
...
@@ -86,6 +87,10 @@ class OpenAIServing:
self
.
engine_client
=
engine_client
self
.
model_config
=
model_config
self
.
max_model_len
=
model_config
.
max_model_len
self
.
tokenizer_mode
=
model_config
.
tokenizer_mode
if
model_config
.
tokenizer_mode
==
"cpm"
:
self
.
tokenizer
=
CPM9GTokenizer
(
model_config
.
model
,
trust_remote_code
=
True
)
self
.
models
=
models
...
...
@@ -189,6 +194,9 @@ class OpenAIServing:
truncation
=
True
,
max_length
=
truncate_prompt_tokens
)
if
self
.
tokenizer_mode
==
"cpm"
:
input_ids
=
[
self
.
tokenizer
.
bos_id
]
+
self
.
tokenizer
.
encode
(
prompt
)
else
:
input_ids
=
encoded
.
input_ids
input_text
=
prompt
...
...
@@ -207,7 +215,7 @@ class OpenAIServing:
else
:
input_ids
=
prompt_ids
[
-
truncate_prompt_tokens
:]
input_text
=
tokenizer
.
decode
(
input_ids
)
input_text
=
tokenizer
.
decode
(
input_ids
)
if
self
.
tokenizer_mode
!=
"cpm"
else
self
.
tokenizer
.
decode_all
(
input_ids
)
return
self
.
_validate_input
(
request
,
input_ids
,
input_text
)
...
...
vllm/inputs/preprocess.py
View file @
8fb5dea5
...
...
@@ -201,6 +201,9 @@ class InputPreprocessor:
"do_lower_case"
,
False
)):
prompt
=
prompt
.
lower
()
if
self
.
model_config
.
tokenizer_mode
==
"cpm"
:
return
[
tokenizer
.
bos_id
]
+
tokenizer
.
encode
(
prompt
)
else
:
return
tokenizer
.
encode
(
prompt
=
prompt
,
lora_request
=
lora_request
,
add_special_tokens
=
add_special_tokens
)
...
...
vllm/model_executor/models/fm9g.py
0 → 100644
View file @
8fb5dea5
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only FM9G model compatible with HuggingFace weights."""
import
math
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Union
,
List
import
torch
from
torch
import
nn
from
vllm.transformers_utils.configs
import
FM9GConfig
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.activation
import
FatreluAndMul
,
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
FM9GMoE
(
nn
.
Module
):
"""A tensor-parallel MoE implementation that shards each expert
across all ranks.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def
__init__
(
self
,
num_experts
:
int
,
top_k
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
):
super
().
__init__
()
self
.
tp_size
=
tp_size
or
get_tensor_model_parallel_world_size
()
self
.
num_total_experts
=
num_experts
self
.
top_k
=
top_k
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
//
self
.
tp_size
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
self
.
gate
=
ReplicatedLinear
(
self
.
hidden_size
,
self
.
num_total_experts
,
bias
=
False
,
params_dtype
=
self
.
params_dtype
,
quant_config
=
None
)
self
.
ws
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_total_experts
,
2
*
self
.
intermediate_size
,
self
.
hidden_size
,
device
=
current_platform
.
device_type
,
dtype
=
self
.
params_dtype
))
self
.
w2s
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_total_experts
,
self
.
hidden_size
,
self
.
intermediate_size
,
device
=
current_platform
.
device_type
,
dtype
=
self
.
params_dtype
))
set_weight_attrs
(
self
.
ws
,
{
"weight_loader"
:
self
.
weight_loader
,
})
set_weight_attrs
(
self
.
w2s
,
{
"weight_loader"
:
self
.
weight_loader
,
})
def
weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
expert_id
:
int
):
tp_rank
=
get_tensor_model_parallel_rank
()
param_data
=
param
.
data
shard_size
=
self
.
intermediate_size
shard
=
slice
(
tp_rank
*
shard_size
,
(
tp_rank
+
1
)
*
shard_size
)
if
weight_name
.
endswith
(
"w1.weight"
):
param_data
[
expert_id
,
0
:
shard_size
,
:]
=
loaded_weight
[
shard
,
:]
if
weight_name
.
endswith
(
"w3.weight"
):
param_data
[
expert_id
,
shard_size
:
2
*
shard_size
,
:]
=
loaded_weight
[
shard
,
:]
if
weight_name
.
endswith
(
"w2.weight"
):
param_data
[
expert_id
,
:,
:]
=
loaded_weight
[:,
shard
]
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_size
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
self
.
hidden_size
)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
fused_moe
(
hidden_states
,
self
.
ws
,
self
.
w2s
,
router_logits
,
self
.
top_k
,
renormalize
=
True
,
inplace
=
True
)
if
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_size
)
class
FM9GMLP
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
hidden_act_param
:
float
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
)
if
hidden_act
==
"silu"
:
self
.
act_fn
=
SiluAndMul
()
elif
hidden_act
==
"fatrelu"
:
self
.
act_fn
=
FatreluAndMul
(
threshold
=
hidden_act_param
)
else
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu and fatrelu are supported for now."
)
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
FM9GAttention
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_kv_heads
=
num_kv_heads
if
self
.
total_num_kv_heads
>=
tp_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_kv_heads
%
tp_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
self
.
head_dim
=
hidden_size
//
self
.
total_num_heads
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
quant_config
=
quant_config
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
)
# set rope as fp32 instead of bf16
self
.
rotary_emb
.
cos_sin_cache
=
self
.
rotary_emb
.
_compute_cos_sin_cache
(
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
orig_dtype
=
q
.
dtype
q
,
k
=
q
.
float
(),
k
.
float
()
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
q
.
to
(
orig_dtype
),
k
.
to
(
orig_dtype
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
FM9GDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
FM9GConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
cache_config
=
cache_config
self
.
quant_config
=
quant_config
self
.
hidden_size
=
config
.
hidden_size
self
.
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
self
.
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
self
.
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
8192
)
self
.
prefix
=
prefix
self
.
_init_attn_block
()
self
.
_init_ffn_block
()
def
_init_attn_block
(
self
):
self
.
input_layernorm
=
RMSNorm
(
self
.
config
.
hidden_size
,
eps
=
self
.
config
.
rms_norm_eps
)
self
.
self_attn
=
FM9GAttention
(
hidden_size
=
self
.
hidden_size
,
num_heads
=
self
.
config
.
num_attention_heads
,
num_kv_heads
=
self
.
config
.
num_key_value_heads
,
rope_theta
=
self
.
rope_theta
,
rope_scaling
=
self
.
rope_scaling
,
max_position_embeddings
=
self
.
max_position_embeddings
,
cache_config
=
self
.
cache_config
,
quant_config
=
self
.
quant_config
,
prefix
=
f
"
{
self
.
prefix
}
.self_attn"
,
)
def
_init_ffn_block
(
self
):
self
.
post_attention_layernorm
=
RMSNorm
(
self
.
config
.
hidden_size
,
eps
=
self
.
config
.
rms_norm_eps
)
self
.
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
0
)
if
self
.
num_experts
==
0
:
self
.
mlp
=
FM9GMLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
self
.
config
.
intermediate_size
,
hidden_act
=
self
.
config
.
hidden_act
,
hidden_act_param
=
getattr
(
self
.
config
,
"hidden_act_param"
,
0.
),
quant_config
=
self
.
quant_config
,
)
else
:
self
.
mlp
=
FM9GMoE
(
num_experts
=
self
.
config
.
num_experts
,
top_k
=
self
.
config
.
num_experts_per_tok
,
hidden_size
=
self
.
config
.
hidden_size
,
intermediate_size
=
self
.
config
.
intermediate_size
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
)
hidden_states
=
residual
+
hidden_states
*
\
(
self
.
config
.
scale_depth
/
math
.
sqrt
(
self
.
config
.
num_hidden_layers
))
# Fully Connected
residual
=
hidden_states
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
*
\
(
self
.
config
.
scale_depth
/
math
.
sqrt
(
self
.
config
.
num_hidden_layers
))
return
hidden_states
,
None
@
support_torch_compile
class
FM9GModel
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
cache_config
=
cache_config
self
.
quant_config
=
quant_config
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
self
.
vocab_size
=
config
.
vocab_size
+
lora_vocab
self
.
org_vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
)
self
.
num_experts
=
getattr
(
self
.
config
,
"num_experts"
,
0
)
self
.
_init_layers
(
prefix
,
config
,
cache_config
,
quant_config
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
self
.
config
.
hidden_size
))
def
_init_layers
(
self
,
prefix
:
str
,
config
:
FM9GConfig
,
cache_config
:
Optional
[
CacheConfig
],
quant_config
:
Optional
[
QuantizationConfig
],
):
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
FM9GDecoderLayer
(
config
,
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
embedding
=
self
.
embed_tokens
(
input_ids
)
return
embedding
*
self
.
config
.
scale_emb
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
else
:
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
expert_params_mapping
=
[
# (param_name, weight_name, expert_id)
(
"ws"
if
weight_name
in
[
"w1"
,
"w3"
]
else
"w2s"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.weight"
,
expert_id
)
for
expert_id
in
range
(
self
.
num_experts
)
for
weight_name
in
[
"w1"
,
"w2"
,
"w3"
]
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
for
param_name
,
weight_name
,
expert_id
in
expert_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
weight_name
,
expert_id
=
expert_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
FM9GForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
# LoRA specific attributes
embedding_modules
=
{
"embed_tokens"
:
"input_embeddings"
,
"lm_head"
:
"output_embeddings"
,
}
embedding_padding_modules
=
[
"lm_head"
]
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
prefix
=
prefix
self
.
vllm_config
=
vllm_config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
cache_config
=
cache_config
self
.
quant_config
=
quant_config
self
.
model
=
self
.
_init_model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
lm_head
=
ParallelLMHead
(
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
quant_config
=
quant_config
,
)
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
lm_head
.
tie_weights
(
self
.
model
.
embed_tokens
)
self
.
scale_width
=
self
.
config
.
hidden_size
/
self
.
config
.
dim_model_base
self
.
logits_processor
=
LogitsProcessor
(
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
_init_model
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
return
FM9GModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
hidden_states
=
hidden_states
/
self
.
scale_width
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"lm_head."
]
if
self
.
config
.
tie_word_embeddings
else
None
),
)
return
loader
.
load_weights
(
weights
)
\ No newline at end of file
vllm/model_executor/models/registry.py
View file @
8fb5dea5
...
...
@@ -53,6 +53,7 @@ _TEXT_GENERATION_MODELS = {
"DeepseekV3ForCausalLM"
:
(
"deepseek_v2"
,
"DeepseekV3ForCausalLM"
),
"ExaoneForCausalLM"
:
(
"exaone"
,
"ExaoneForCausalLM"
),
"FalconForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"FM9GForCausalLM"
:
(
"fm9g"
,
"FM9GForCausalLM"
),
"Fairseq2LlamaForCausalLM"
:
(
"fairseq2_llama"
,
"Fairseq2LlamaForCausalLM"
),
"GemmaForCausalLM"
:
(
"gemma"
,
"GemmaForCausalLM"
),
"Gemma2ForCausalLM"
:
(
"gemma2"
,
"Gemma2ForCausalLM"
),
...
...
vllm/transformers_utils/configs/__init__.py
View file @
8fb5dea5
...
...
@@ -10,6 +10,7 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
from
vllm.transformers_utils.configs.falcon
import
RWConfig
from
vllm.transformers_utils.configs.fm9g
import
FM9GConfig
from
vllm.transformers_utils.configs.h2ovl
import
H2OVLChatConfig
from
vllm.transformers_utils.configs.internvl
import
InternVLChatConfig
from
vllm.transformers_utils.configs.jais
import
JAISConfig
...
...
@@ -31,6 +32,7 @@ __all__ = [
"Cohere2Config"
,
"DbrxConfig"
,
"DeepseekVLV2Config"
,
"FM9GConfig"
,
"MPTConfig"
,
"RWConfig"
,
"H2OVLChatConfig"
,
...
...
vllm/transformers_utils/configs/fm9g.py
0 → 100644
View file @
8fb5dea5
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FM9G model configuration"""
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.utils
import
logging
logger
=
logging
.
get_logger
(
__name__
)
FM9G_PRETRAINED_CONFIG_ARCHIVE_MAP
=
{}
class
FM9GConfig
(
PretrainedConfig
):
r
"""
This is the configuration class to store the configuration of a [`FM9GModel`]. It is used to instantiate an FM9G
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the FM9G-7B.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the FM9G model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`FM9GModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
issue](https://github.com/pytorch/pytorch/issues/76232).
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
"""
model_type
=
"fm9g"
keys_to_ignore_at_inference
=
[
"past_key_values"
]
def
__init__
(
self
,
vocab_size
=
32000
,
hidden_size
=
4096
,
intermediate_size
=
11008
,
num_hidden_layers
=
32
,
num_attention_heads
=
32
,
num_key_value_heads
=
None
,
hidden_act
=
"silu"
,
max_position_embeddings
=
2048
,
initializer_range
=
0.02
,
rms_norm_eps
=
1e-6
,
use_cache
=
True
,
pad_token_id
=
None
,
bos_token_id
=
1
,
eos_token_id
=
2
,
pretraining_tp
=
1
,
tie_word_embeddings
=
True
,
rope_theta
=
10000.0
,
rope_scaling
=
None
,
attention_bias
=
False
,
attention_dropout
=
0.0
,
scale_emb
=
1
,
dim_model_base
=
1
,
scale_depth
=
1
,
**
kwargs
,
):
self
.
vocab_size
=
vocab_size
self
.
max_position_embeddings
=
max_position_embeddings
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_attention_heads
=
num_attention_heads
# for backward compatibility
if
num_key_value_heads
is
None
:
num_key_value_heads
=
num_attention_heads
self
.
num_key_value_heads
=
num_key_value_heads
self
.
hidden_act
=
hidden_act
self
.
initializer_range
=
initializer_range
self
.
rms_norm_eps
=
rms_norm_eps
self
.
pretraining_tp
=
pretraining_tp
self
.
use_cache
=
use_cache
self
.
rope_theta
=
rope_theta
self
.
rope_scaling
=
rope_scaling
self
.
_rope_scaling_validation
()
self
.
attention_bias
=
attention_bias
self
.
attention_dropout
=
attention_dropout
self
.
scale_emb
=
scale_emb
self
.
dim_model_base
=
dim_model_base
self
.
scale_depth
=
scale_depth
super
().
__init__
(
pad_token_id
=
pad_token_id
,
bos_token_id
=
bos_token_id
,
eos_token_id
=
eos_token_id
,
tie_word_embeddings
=
tie_word_embeddings
,
**
kwargs
,
)
try
:
import
flash_attn
self
.
_attn_implementation
=
"flash_attention_2"
except
:
pass
def
_rope_scaling_validation
(
self
):
"""
Validate the `rope_scaling` configuration.
"""
if
self
.
rope_scaling
is
None
:
return
if
not
isinstance
(
self
.
rope_scaling
,
dict
)
or
len
(
self
.
rope_scaling
)
!=
2
:
raise
ValueError
(
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
f
"got
{
self
.
rope_scaling
}
"
)
rope_scaling_type
=
self
.
rope_scaling
.
get
(
"type"
,
None
)
rope_scaling_factor
=
self
.
rope_scaling
.
get
(
"factor"
,
None
)
if
rope_scaling_type
is
None
or
rope_scaling_type
not
in
[
"linear"
,
"dynamic"
]:
raise
ValueError
(
f
"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got
{
rope_scaling_type
}
"
)
if
rope_scaling_factor
is
None
or
not
isinstance
(
rope_scaling_factor
,
float
)
or
rope_scaling_factor
<=
1.0
:
raise
ValueError
(
f
"`rope_scaling`'s factor field must be a float > 1, got
{
rope_scaling_factor
}
"
)
\ No newline at end of file
vllm/transformers_utils/detokenizer.py
View file @
8fb5dea5
...
...
@@ -14,8 +14,12 @@ from .tokenizer_group import TokenizerGroup
class
Detokenizer
:
"""Provides methods to decode the output of a model into text."""
def
__init__
(
self
,
tokenizer_group
:
TokenizerGroup
):
def
__init__
(
self
,
tokenizer_group
:
TokenizerGroup
,
mode
=
"auto"
):
self
.
mode
=
mode
if
self
.
mode
!=
"cpm"
:
self
.
tokenizer_group
=
tokenizer_group
else
:
self
.
tokenizer
=
tokenizer_group
def
get_tokenizer_for_seq
(
self
,
sequence
:
Sequence
)
->
AnyTokenizer
:
"""Returns the HF tokenizer to use for a given sequence."""
...
...
@@ -44,7 +48,10 @@ class Detokenizer:
# Only prompt, without the generated token.
all_token_ids
=
seq
.
get_token_ids
()
prompt_token_ids
=
all_token_ids
[:
-
1
]
if
self
.
mode
!=
"cpm"
:
tokenizer
=
self
.
get_tokenizer_for_seq
(
seq
)
else
:
tokenizer
=
self
.
tokenizer
prefix_offset
=
0
read_offset
=
0
next_iter_prefix_offset
=
0
...
...
@@ -76,6 +83,7 @@ class Detokenizer:
skip_special_tokens
=
prms
.
skip_special_tokens
,
spaces_between_special_tokens
=
prms
.
spaces_between_special_tokens
,
mode
=
self
.
mode
,
)
sample_logprob
.
decoded_token
=
new_text
...
...
@@ -109,7 +117,10 @@ class Detokenizer:
"""
all_input_ids
=
seq
.
get_token_ids
()
token_id_generated_this_iteration
=
all_input_ids
[
-
1
]
if
self
.
mode
!=
"cpm"
:
tokenizer
=
self
.
get_tokenizer_for_seq
(
seq
)
else
:
tokenizer
=
self
.
tokenizer
# Convert prompt token IDs to tokens if necessary.
# Do it here so that we don't have to repeat this
...
...
@@ -131,6 +142,7 @@ class Detokenizer:
read_offset
=
seq
.
read_offset
,
skip_special_tokens
=
prms
.
skip_special_tokens
,
spaces_between_special_tokens
=
prms
.
spaces_between_special_tokens
,
mode
=
self
.
mode
,
)
# Decode logprobs
...
...
@@ -156,6 +168,7 @@ class Detokenizer:
skip_special_tokens
=
prms
.
skip_special_tokens
,
spaces_between_special_tokens
=
prms
.
spaces_between_special_tokens
,
mode
=
self
.
mode
,
)
sample_logprob
.
decoded_token
=
new_text
...
...
vllm/transformers_utils/detokenizer_utils.py
View file @
8fb5dea5
...
...
@@ -16,6 +16,7 @@ def _convert_tokens_to_string_with_added_encoders(
output_tokens
:
List
[
str
],
skip_special_tokens
:
bool
,
spaces_between_special_tokens
:
bool
,
mode
:
str
,
)
->
str
:
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
...
...
@@ -24,7 +25,10 @@ def _convert_tokens_to_string_with_added_encoders(
# even when the loop body is very simple.
sub_texts
:
List
[
str
]
=
[]
current_sub_text
:
List
[
str
]
=
[]
if
mode
!=
"cpm"
:
all_special_tokens
=
set
(
tokenizer
.
all_special_tokens
)
else
:
all_special_tokens
=
tokenizer
.
_special_token_set
for
token
in
output_tokens
:
if
skip_special_tokens
and
token
in
all_special_tokens
:
continue
...
...
@@ -37,7 +41,10 @@ def _convert_tokens_to_string_with_added_encoders(
else
:
current_sub_text
.
append
(
token
)
if
current_sub_text
:
if
mode
!=
"cpm"
:
sub_text
=
tokenizer
.
convert_tokens_to_string
(
current_sub_text
)
else
:
sub_text
=
tokenizer
.
decode
(
current_sub_text
)
sub_texts
.
append
(
sub_text
)
if
spaces_between_special_tokens
:
return
" "
.
join
(
sub_texts
)
...
...
@@ -104,6 +111,7 @@ def detokenize_incrementally(
read_offset
:
int
,
skip_special_tokens
:
bool
=
False
,
spaces_between_special_tokens
:
bool
=
True
,
mode
:
str
=
"cpm"
,
)
->
Tuple
[
List
[
str
],
str
,
int
,
int
]:
"""Detokenizes the input ids incrementally and returns the new tokens
and the new text.
...
...
@@ -141,7 +149,11 @@ def detokenize_incrementally(
assert
prev_tokens
is
not
None
# If the new token id is out of bounds, return an empty string.
if
0
<=
new_token_id
<
len
(
tokenizer
):
if
mode
==
"cpm"
:
vocab_size
=
tokenizer
.
vocab_size
else
:
vocab_size
=
len
(
tokenizer
)
if
0
<=
new_token_id
<
vocab_size
:
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens
=
tokenizer
.
convert_ids_to_tokens
(
[
new_token_id
],
skip_special_tokens
=
skip_special_tokens
)
...
...
@@ -169,12 +181,14 @@ def detokenize_incrementally(
output_tokens
[
prefix_offset
:
read_offset
],
skip_special_tokens
=
skip_special_tokens
,
spaces_between_special_tokens
=
spaces_between_special_tokens
,
mode
=
mode
,
)
new_text
=
_convert_tokens_to_string_with_added_encoders
(
tokenizer
,
output_tokens
[
prefix_offset
:],
skip_special_tokens
=
skip_special_tokens
,
spaces_between_special_tokens
=
spaces_between_special_tokens
,
mode
=
mode
,
)
if
len
(
new_text
)
<=
len
(
prefix_text
)
or
new_text
.
endswith
(
"�"
):
...
...
vllm/transformers_utils/tokenizers/__init__.py
View file @
8fb5dea5
...
...
@@ -2,8 +2,10 @@
from
.mistral
import
(
MistralTokenizer
,
maybe_serialize_tool_calls
,
truncate_tool_call_ids
,
validate_request_params
)
from
vllm.transformers_utils.tokenizers.cpm_9g
import
CPM9GTokenizer
__all__
=
[
"MistralTokenizer"
,
"maybe_serialize_tool_calls"
,
"truncate_tool_call_ids"
,
"validate_request_params"
"validate_request_params"
,
"CPM9GTokenizer"
]
vllm/transformers_utils/tokenizers/cpm_9g.py
0 → 100644
View file @
8fb5dea5
import
io
import
json
import
os
from
shutil
import
copyfile
from
typing
import
Any
,
Dict
,
IO
,
List
,
Optional
,
Tuple
import
pkg_resources
import
sentencepiece
as
spm
from
pytrie
import
StringTrie
from
transformers.tokenization_utils
import
AddedToken
,
PreTrainedTokenizer
from
transformers.utils
import
logging
logger
=
logging
.
get_logger
(
__name__
)
VOCAB_FILES_NAMES
=
{
"vocab_file"
:
"vocab.txt"
}
PRETRAINED_VOCAB_FILES_MAP
=
{
"vocab_file"
:
{},
"tokenizer_file"
:
{},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
=
{}
class
CPM9GTokenizer
(
PreTrainedTokenizer
):
"""
CPM9G 分词器类。用于基于字节对编码的分词。
参数:
path (str, 可选): 词汇表文件的路径。
"""
vocab_files_names
=
VOCAB_FILES_NAMES
pretrained_vocab_files_map
=
PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes
=
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names
=
[
"input_ids"
,
"attention_mask"
]
def
__init__
(
self
,
vocab_file
:
Optional
[
str
]
=
None
,
unk_token
:
str
=
"<unk>"
,
bos_token
:
str
=
"<s>"
,
eos_token
:
str
=
"</s>"
,
pad_token
:
Optional
[
str
]
=
None
,
sp_model_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
add_bos_token
:
bool
=
True
,
add_eos_token
:
bool
=
False
,
clean_up_tokenization_spaces
:
bool
=
False
,
**
kwargs
,
):
self
.
sp_model_kwargs
=
sp_model_kwargs
or
{}
self
.
vocab_file
=
vocab_file
self
.
add_bos_token
=
add_bos_token
self
.
add_eos_token
=
add_eos_token
self
.
unk_token
=
unk_token
self
.
bos_token
=
bos_token
self
.
eos_token
=
eos_token
self
.
pad_token
=
pad_token
self
.
byte_list
:
List
[
str
]
=
(
[
f
"<0x0
{
hex
(
i
).
upper
()[
2
:]
}
>"
for
i
in
range
(
0x10
)]
+
[
f
"<0x
{
hex
(
i
).
upper
()[
2
:]
}
>"
for
i
in
range
(
0x10
,
0x100
)]
)
self
.
_special_token_set
=
set
([
self
.
unk_token
,
self
.
bos_token
,
self
.
eos_token
]
+
self
.
byte_list
)
if
vocab_file
:
if
'vocab.txt'
not
in
vocab_file
:
all_tokens
=
self
.
load_vocab
(
io
.
FileIO
(
os
.
path
.
join
(
vocab_file
,
VOCAB_FILES_NAMES
[
'vocab_file'
]),
"rb"
))
else
:
all_tokens
=
self
.
load_vocab
(
io
.
FileIO
(
VOCAB_FILES_NAMES
[
'vocab_file'
],
"rb"
))
self
.
encoder
:
Dict
[
str
,
int
]
=
{}
self
.
_special_encoder
:
Dict
[
str
,
int
]
=
{}
for
token
,
token_id
in
all_tokens
.
items
():
if
token
in
self
.
_special_token_set
:
self
.
_special_encoder
[
token
]
=
token_id
else
:
self
.
encoder
[
token
]
=
token_id
self
.
decoder
=
{
v
:
k
for
k
,
v
in
self
.
encoder
.
items
()}
self
.
_byte_decoder
=
{
self
.
_special_encoder
[
token
]:
i
for
i
,
token
in
enumerate
(
self
.
byte_list
)}
self
.
_max_word_len
=
max
([
len
(
x
)
for
x
in
self
.
encoder
.
keys
()])
self
.
_len_word_first
=
{}
for
x
in
self
.
encoder
.
keys
():
if
not
x
[
0
]
in
self
.
_len_word_first
:
self
.
_len_word_first
[
x
[
0
]]
=
1
if
len
(
x
)
>
self
.
_len_word_first
[
x
[
0
]]:
self
.
_len_word_first
[
x
[
0
]]
=
len
(
x
)
self
.
tencoder
=
StringTrie
(
self
.
encoder
)
self
.
_max_token_id
=
self
.
vocab_size
-
1
super
().
__init__
(
bos_token
=
AddedToken
(
bos_token
,
lstrip
=
False
,
rstrip
=
False
),
eos_token
=
AddedToken
(
eos_token
,
lstrip
=
False
,
rstrip
=
False
),
unk_token
=
AddedToken
(
unk_token
,
lstrip
=
False
,
rstrip
=
False
),
pad_token
=
AddedToken
(
pad_token
,
lstrip
=
False
,
rstrip
=
False
)
if
pad_token
else
None
,
add_bos_token
=
add_bos_token
,
add_eos_token
=
add_eos_token
,
sp_model_kwargs
=
self
.
sp_model_kwargs
,
clean_up_tokenization_spaces
=
clean_up_tokenization_spaces
,
**
kwargs
,
)
def
__getstate__
(
self
)
->
Dict
[
str
,
Any
]:
state
=
self
.
__dict__
.
copy
()
state
[
"sp_model"
]
=
None
return
state
def
__setstate__
(
self
,
d
:
Dict
[
str
,
Any
])
->
None
:
self
.
__dict__
=
d
def
load_vocab
(
self
,
fp
:
IO
[
bytes
])
->
Dict
[
str
,
int
]:
"""
加载词汇表文件到字典中。
参数:
fp (IO[bytes]): 词汇表文件指针。
返回:
Dict[str, int]: 词汇表字典。
"""
vocab
:
Dict
[
str
,
int
]
=
{}
reader
=
io
.
TextIOWrapper
(
fp
,
encoding
=
"utf-8"
)
for
token
in
reader
.
readlines
():
token
=
token
.
strip
()
if
len
(
token
)
==
0
:
continue
token
=
json
.
loads
(
token
)
vocab
[
token
]
=
len
(
vocab
)
return
vocab
@
property
def
vocab_size
(
self
)
->
int
:
"""返回词汇表大小"""
return
len
(
self
.
encoder
)
+
len
(
self
.
_special_encoder
)
@
property
def
max_token_id
(
self
)
->
int
:
return
self
.
_max_token_id
@
property
def
eos_id
(
self
):
return
self
.
_special_encoder
[
self
.
eos_token
]
@
property
def
bos_id
(
self
):
return
self
.
_special_encoder
[
self
.
bos_token
]
@
property
def
unk_id
(
self
):
return
self
.
_special_encoder
[
self
.
unk_token
]
def
get_vocab
(
self
)
->
Dict
[
str
,
int
]:
"""返回词汇表作为字典"""
vocab
=
{
self
.
convert_ids_to_tokens
(
i
):
i
for
i
in
range
(
self
.
vocab_size
)}
vocab
.
update
(
self
.
added_tokens_encoder
)
return
vocab
def
_tokenize
(
self
,
text
:
str
)
->
List
[
str
]:
"""返回分词后的字符串"""
output_tokens
:
List
[
str
]
=
[]
st
=
0
while
st
<
len
(
text
):
piece
=
self
.
get_piece
(
text
[
st
:])
output_tokens
.
append
(
piece
)
st
+=
len
(
piece
)
return
output_tokens
def
_convert_token_to_id
(
self
,
token
:
str
)
->
int
:
"""使用词汇表将标记(字符串)转换为 id"""
return
self
.
encoder
.
get
(
token
,
self
.
unk_id
)
def
_convert_id_to_token
(
self
,
index
:
int
)
->
str
:
"""使用词汇表将索引(整数)转换为标记(字符串)"""
return
self
.
decoder
.
get
(
index
,
self
.
unk_token
)
def
convert_tokens_to_string
(
self
,
tokens
:
List
[
str
])
->
str
:
"""将标记序列(字符串)转换为单个字符串"""
current_sub_tokens
:
List
[
str
]
=
[]
out_string
=
""
prev_is_special
=
False
for
i
,
token
in
enumerate
(
tokens
):
if
token
in
self
.
_special_token_set
:
if
not
prev_is_special
and
i
!=
0
:
out_string
+=
" "
out_string
+=
self
.
decode
(
current_sub_tokens
)
+
token
prev_is_special
=
True
current_sub_tokens
=
[]
else
:
current_sub_tokens
.
append
(
token
)
prev_is_special
=
False
out_string
+=
self
.
sp_model
.
decode
(
current_sub_tokens
)
return
out_string
def
save_vocabulary
(
self
,
save_directory
:
str
,
filename_prefix
:
Optional
[
str
]
=
None
)
->
Tuple
[
str
]:
"""
保存词汇表和特殊标记文件到目录。
参数:
save_directory (str): 要保存词汇表的目录。
返回:
Tuple[str]: 保存的文件路径。
"""
if
not
os
.
path
.
isdir
(
save_directory
):
raise
ValueError
(
f
"Vocabulary path (
{
save_directory
}
) should be a directory"
)
out_vocab_file
=
os
.
path
.
join
(
save_directory
,
(
filename_prefix
+
"-"
if
filename_prefix
else
""
)
+
VOCAB_FILES_NAMES
[
"vocab_file"
],
)
if
os
.
path
.
abspath
(
self
.
vocab_file
)
!=
os
.
path
.
abspath
(
out_vocab_file
)
and
os
.
path
.
isfile
(
self
.
vocab_file
):
copyfile
(
self
.
vocab_file
,
out_vocab_file
)
elif
not
os
.
path
.
isfile
(
self
.
vocab_file
):
with
open
(
out_vocab_file
,
"wb"
)
as
fi
:
fi
.
write
(
self
.
sp_model
.
serialized_model_proto
())
return
(
out_vocab_file
,
)
def
build_inputs_with_special_tokens
(
self
,
token_ids_0
:
List
[
int
],
token_ids_1
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
int
]:
bos_token_id
=
[
self
.
bos_token_id
]
if
self
.
add_bos_token
else
[]
eos_token_id
=
[
self
.
eos_token_id
]
if
self
.
add_eos_token
else
[]
output
=
bos_token_id
+
token_ids_0
+
eos_token_id
if
token_ids_1
is
not
None
:
output
=
output
+
bos_token_id
+
token_ids_1
+
eos_token_id
return
output
def
get_special_tokens_mask
(
self
,
token_ids_0
:
List
[
int
],
token_ids_1
:
Optional
[
List
[
int
]]
=
None
,
already_has_special_tokens
:
bool
=
False
)
->
List
[
int
]:
"""
获取从未添加特殊标记的标记列表中检索到的序列 id。
在使用分词器的 `prepare_for_model` 方法添加特殊标记时调用此方法。
参数:
token_ids_0 (List[int]): id 列表。
token_ids_1 (List[int], 可选): 序列对的可选第二 id 列表。
already_has_special_tokens (bool, 可选, 默认值为 False):
标记列表是否已使用模型的特殊标记进行格式化。
返回:
List[int]: 一个包含整数(0 或 1)的列表。1 表示特殊标记,0 表示序列标记。
"""
if
already_has_special_tokens
:
return
super
().
get_special_tokens_mask
(
token_ids_0
=
token_ids_0
,
token_ids_1
=
token_ids_1
,
already_has_special_tokens
=
True
,
)
bos_token_id
=
[
1
]
if
self
.
add_bos_token
else
[]
eos_token_id
=
[
1
]
if
self
.
add_eos_token
else
[]
if
token_ids_1
is
None
:
return
bos_token_id
+
([
0
]
*
len
(
token_ids_0
))
+
eos_token_id
return
bos_token_id
+
([
0
]
*
len
(
token_ids_0
))
+
eos_token_id
+
bos_token_id
+
([
0
]
*
len
(
token_ids_1
))
+
eos_token_id
def
create_token_type_ids_from_sequences
(
self
,
token_ids_0
:
List
[
int
],
token_ids_1
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
int
]:
"""
从传递的两个序列创建掩码,用于序列对分类任务。
参数:
token_ids_0 (List[int]): id 列表。
token_ids_1 (List[int], 可选): 序列对的可选第二 id 列表。
返回:
List[int]: 根据给定序列的标记类型 id 列表。
"""
bos_token_id
=
[
self
.
bos_token_id
]
if
self
.
add_bos_token
else
[]
eos_token_id
=
[
self
.
eos_token_id
]
if
self
.
add_eos_token
else
[]
output
=
[
0
]
*
len
(
bos_token_id
+
token_ids_0
+
eos_token_id
)
if
token_ids_1
is
not
None
:
output
+=
[
1
]
*
len
(
bos_token_id
+
token_ids_1
+
eos_token_id
)
return
output
def
get_piece
(
self
,
text
:
str
)
->
str
:
"""
获取文本中的分词片段。
参数:
text (str): 输入文本。
返回:
str: 分词片段。
"""
if
text
[
0
]
in
self
.
_len_word_first
:
text
=
text
[:
self
.
_len_word_first
[
text
[
0
]]]
len_text
=
len
(
text
)
for
i
in
range
(
len
(
text
)):
sub
=
text
[:
len_text
-
i
]
if
sub
in
self
.
encoder
:
return
sub
return
text
[
0
]
def
encode
(
self
,
text
:
str
)
->
List
[
int
]:
"""
将文本编码为 ID 列表。
参数:
text (str): 输入文本。
返回:
List[int]: 编码后的 ID 列表。
"""
#if len(text) > 20480:
# return [0 for _ in range(20480)]
ret
=
[]
for
x
in
self
.
_tokenize
(
text
):
if
x
in
self
.
encoder
:
ret
.
append
(
self
.
encoder
[
x
])
else
:
ret
.
extend
(
self
.
_encode_unicode
(
x
))
return
ret
def
decode_all
(
self
,
tokens
:
List
[
int
]):
"""Decode ids into a string."""
ret
=
[]
st
=
0
while
st
<
len
(
tokens
):
if
tokens
[
st
]
in
self
.
decoder
:
ret
.
append
(
self
.
decoder
[
tokens
[
st
]])
st
+=
1
elif
tokens
[
st
]
in
self
.
_byte_decoder
:
if
(
st
+
3
<
len
(
tokens
)
and
tokens
[
st
+
1
]
in
self
.
_byte_decoder
and
tokens
[
st
+
2
]
in
self
.
_byte_decoder
and
tokens
[
st
+
3
]
in
self
.
_byte_decoder
):
first_id
=
self
.
_byte_decoder
[
tokens
[
st
]]
plane_id
=
self
.
_byte_decoder
[
tokens
[
st
+
1
]]
row_id
=
self
.
_byte_decoder
[
tokens
[
st
+
2
]]
cell_id
=
self
.
_byte_decoder
[
tokens
[
st
+
3
]]
ret
.
append
(
int
.
to_bytes
(
first_id
<<
24
|
plane_id
<<
16
|
row_id
<<
8
|
cell_id
,
4
,
"big"
).
decode
(
"utf-8"
)
)
st
+=
4
elif
(
st
+
2
<
len
(
tokens
)
and
tokens
[
st
+
1
]
in
self
.
_byte_decoder
and
tokens
[
st
+
2
]
in
self
.
_byte_decoder
):
plane_id
=
self
.
_byte_decoder
[
tokens
[
st
]]
row_id
=
self
.
_byte_decoder
[
tokens
[
st
+
1
]]
cell_id
=
self
.
_byte_decoder
[
tokens
[
st
+
2
]]
ret
.
append
(
int
.
to_bytes
(
plane_id
<<
16
|
row_id
<<
8
|
cell_id
,
3
,
"big"
).
decode
(
"utf-8"
))
st
+=
3
elif
st
+
1
<
len
(
tokens
)
and
tokens
[
st
+
1
]
in
self
.
_byte_decoder
:
row_id
=
self
.
_byte_decoder
[
tokens
[
st
]]
cell_id
=
self
.
_byte_decoder
[
tokens
[
st
+
1
]]
ret
.
append
(
int
.
to_bytes
(
row_id
<<
8
|
cell_id
,
2
,
"big"
).
decode
(
"utf-8"
))
st
+=
2
else
:
cell_id
=
self
.
_byte_decoder
[
tokens
[
st
]]
ret
.
append
(
int
.
to_bytes
(
cell_id
,
1
,
"big"
).
decode
(
"utf-8"
))
st
+=
1
elif
tokens
[
st
]
==
self
.
eos_id
:
ret
.
append
(
self
.
eos_token
)
st
+=
1
elif
tokens
[
st
]
==
self
.
bos_id
:
ret
.
append
(
self
.
bos_token
)
st
+=
1
else
:
ret
.
append
(
self
.
unk_token
)
st
+=
1
return
""
.
join
(
ret
)
def
decode
(
self
,
tokens
:
List
[
int
])
->
str
:
"""
将 ID 列表解码为字符串。
参数:
tokens (List[int]): ID 列表。
返回:
str: 解码后的字符串。
"""
ret
=
[]
st
=
0
while
st
<
len
(
tokens
):
if
tokens
[
st
]
in
self
.
_byte_decoder
:
if
(
st
+
3
<
len
(
tokens
)
and
tokens
[
st
+
1
]
in
self
.
_byte_decoder
and
tokens
[
st
+
2
]
in
self
.
_byte_decoder
and
tokens
[
st
+
3
]
in
self
.
_byte_decoder
):
first_id
=
self
.
_byte_decoder
[
tokens
[
st
]]
plane_id
=
self
.
_byte_decoder
[
tokens
[
st
+
1
]]
row_id
=
self
.
_byte_decoder
[
tokens
[
st
+
2
]]
cell_id
=
self
.
_byte_decoder
[
tokens
[
st
+
3
]]
ret
.
append
(
int
.
to_bytes
(
first_id
<<
24
|
plane_id
<<
16
|
row_id
<<
8
|
cell_id
,
4
,
"big"
).
decode
(
"utf-8"
)
)
st
+=
4
elif
(
st
+
2
<
len
(
tokens
)
and
tokens
[
st
+
1
]
in
self
.
_byte_decoder
and
tokens
[
st
+
2
]
in
self
.
_byte_decoder
):
plane_id
=
self
.
_byte_decoder
[
tokens
[
st
]]
row_id
=
self
.
_byte_decoder
[
tokens
[
st
+
1
]]
cell_id
=
self
.
_byte_decoder
[
tokens
[
st
+
2
]]
ret
.
append
(
int
.
to_bytes
(
plane_id
<<
16
|
row_id
<<
8
|
cell_id
,
3
,
"big"
).
decode
(
"utf-8"
))
st
+=
3
elif
st
+
1
<
len
(
tokens
)
and
tokens
[
st
+
1
]
in
self
.
_byte_decoder
:
row_id
=
self
.
_byte_decoder
[
tokens
[
st
]]
cell_id
=
self
.
_byte_decoder
[
tokens
[
st
+
1
]]
ret
.
append
(
int
.
to_bytes
(
row_id
<<
8
|
cell_id
,
2
,
"big"
).
decode
(
"utf-8"
))
st
+=
2
else
:
cell_id
=
self
.
_byte_decoder
[
tokens
[
st
]]
ret
.
append
(
int
.
to_bytes
(
cell_id
,
1
,
"big"
).
decode
(
"utf-8"
))
st
+=
1
elif
tokens
[
st
]
==
self
.
eos_id
:
ret
.
append
(
self
.
eos_token
)
st
+=
1
elif
tokens
[
st
]
==
self
.
bos_id
:
ret
.
append
(
self
.
bos_token
)
st
+=
1
else
:
ret
.
append
(
tokens
[
st
])
st
+=
1
#else:
# ret.append(self.unk_token)
# st += 1
return
''
.
join
(
ret
)
def
_encode_unicode
(
self
,
token
:
str
)
->
List
[
int
]:
"""
将 Unicode 编码包装到一个辅助函数中。
参数:
token (str): 要编码的标记。
返回:
List[int]: 编码后的 ID 列表。
"""
ids
=
[]
utf8_id
=
token
.
encode
(
"utf-8"
)
for
_id
in
utf8_id
:
ids
.
append
(
self
.
_special_encoder
[
self
.
byte_list
[
_id
]])
return
ids
def
next_token
(
self
,
text
:
str
)
->
Tuple
[
str
,
List
[
int
]]:
"""
快速获取下一个匹配的标记。
参数:
text (str): 输入文本。
返回:
Tuple[str, List[int]]: 匹配的标记及其 ID 列表。
"""
token
,
token_id
=
self
.
tencoder
.
longest_prefix_item
(
text
,
(
None
,
None
))
if
token
is
None
:
token
=
text
[
0
]
token_ids
=
self
.
_encode_unicode
(
token
)
else
:
token_ids
=
[
token_id
]
return
token
,
token_ids
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment