Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
65b1cbb1
Unverified
Commit
65b1cbb1
authored
Jun 30, 2025
by
redmoe-moutain
Committed by
GitHub
Jun 29, 2025
Browse files
[Model] support dots1 (#18254)
Signed-off-by:
redmoe-moutain
<
agiredmoe@gmail.com
>
parent
6c9837a7
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
539 additions
and
0 deletions
+539
-0
docs/models/supported_models.md
docs/models/supported_models.md
+1
-0
tests/models/registry.py
tests/models/registry.py
+2
-0
vllm/model_executor/models/dots1.py
vllm/model_executor/models/dots1.py
+535
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
No files found.
docs/models/supported_models.md
View file @
65b1cbb1
...
@@ -329,6 +329,7 @@ Specified using `--task generate`.
...
@@ -329,6 +329,7 @@ Specified using `--task generate`.
|
`DeepseekForCausalLM`
| DeepSeek |
`deepseek-ai/deepseek-llm-67b-base`
,
`deepseek-ai/deepseek-llm-7b-chat`
etc. | | ✅︎ | ✅︎ |
|
`DeepseekForCausalLM`
| DeepSeek |
`deepseek-ai/deepseek-llm-67b-base`
,
`deepseek-ai/deepseek-llm-7b-chat`
etc. | | ✅︎ | ✅︎ |
|
`DeepseekV2ForCausalLM`
| DeepSeek-V2 |
`deepseek-ai/DeepSeek-V2`
,
`deepseek-ai/DeepSeek-V2-Chat`
etc. | | ✅︎ | ✅︎ |
|
`DeepseekV2ForCausalLM`
| DeepSeek-V2 |
`deepseek-ai/DeepSeek-V2`
,
`deepseek-ai/DeepSeek-V2-Chat`
etc. | | ✅︎ | ✅︎ |
|
`DeepseekV3ForCausalLM`
| DeepSeek-V3 |
`deepseek-ai/DeepSeek-V3-Base`
,
`deepseek-ai/DeepSeek-V3`
etc. | | ✅︎ | ✅︎ |
|
`DeepseekV3ForCausalLM`
| DeepSeek-V3 |
`deepseek-ai/DeepSeek-V3-Base`
,
`deepseek-ai/DeepSeek-V3`
etc. | | ✅︎ | ✅︎ |
|
`Dots1ForCausalLM`
| dots.llm1 |
`rednote-hilab/dots.llm1.base`
,
`rednote-hilab/dots.llm1.inst`
etc. | | ✅︎ | ✅︎ |
|
`ExaoneForCausalLM`
| EXAONE-3 |
`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`ExaoneForCausalLM`
| EXAONE-3 |
`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`FalconForCausalLM`
| Falcon |
`tiiuae/falcon-7b`
,
`tiiuae/falcon-40b`
,
`tiiuae/falcon-rw-7b`
, etc. | | ✅︎ | ✅︎ |
|
`FalconForCausalLM`
| Falcon |
`tiiuae/falcon-7b`
,
`tiiuae/falcon-40b`
,
`tiiuae/falcon-rw-7b`
, etc. | | ✅︎ | ✅︎ |
|
`FalconMambaForCausalLM`
| FalconMamba |
`tiiuae/falcon-mamba-7b`
,
`tiiuae/falcon-mamba-7b-instruct`
, etc. | | ✅︎ | ✅︎ |
|
`FalconMambaForCausalLM`
| FalconMamba |
`tiiuae/falcon-mamba-7b`
,
`tiiuae/falcon-mamba-7b-instruct`
, etc. | | ✅︎ | ✅︎ |
...
...
tests/models/registry.py
View file @
65b1cbb1
...
@@ -268,6 +268,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -268,6 +268,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Zamba2ForCausalLM"
:
_HfExamplesInfo
(
"Zyphra/Zamba2-7B-instruct"
),
"Zamba2ForCausalLM"
:
_HfExamplesInfo
(
"Zyphra/Zamba2-7B-instruct"
),
"MiMoForCausalLM"
:
_HfExamplesInfo
(
"XiaomiMiMo/MiMo-7B-RL"
,
"MiMoForCausalLM"
:
_HfExamplesInfo
(
"XiaomiMiMo/MiMo-7B-RL"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"Dots1ForCausalLM"
:
_HfExamplesInfo
(
"rednote-hilab/dots.llm1.inst"
,
min_transformers_version
=
"4.53"
),
# [Encoder-decoder]
# [Encoder-decoder]
"BartModel"
:
_HfExamplesInfo
(
"facebook/bart-base"
),
"BartModel"
:
_HfExamplesInfo
(
"facebook/bart-base"
),
"BartForConditionalGeneration"
:
_HfExamplesInfo
(
"facebook/bart-large-cnn"
),
"BartForConditionalGeneration"
:
_HfExamplesInfo
(
"facebook/bart-large-cnn"
),
...
...
vllm/model_executor/models/dots1.py
0 → 100644
View file @
65b1cbb1
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2025 The rednote-hilab team.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only dots1 model."""
from
collections.abc
import
Iterable
from
typing
import
Any
,
Optional
,
Union
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
ModelConfig
,
VllmConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
class
Dots1MLP
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
reduce_results
:
bool
=
True
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
reduce_results
=
reduce_results
,
prefix
=
f
"
{
prefix
}
.down_proj"
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
Dots1MoE
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
routed_scaling_factor
=
config
.
routed_scaling_factor
self
.
n_shared_experts
=
config
.
n_shared_experts
if
config
.
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
config
.
hidden_act
}
. "
"Only silu is supported for now."
)
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
n_routed_experts
,
bias
=
False
,
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.gate"
)
if
config
.
topk_method
==
"noaux_tc"
:
self
.
gate
.
e_score_correction_bias
=
(
nn
.
Parameter
(
torch
.
empty
(
config
.
n_routed_experts
)))
else
:
self
.
gate
.
e_score_correction_bias
=
None
self
.
experts
=
FusedMoE
(
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
reduce_results
=
False
,
renormalize
=
config
.
norm_topk_prob
,
quant_config
=
quant_config
,
use_grouped_topk
=
True
,
num_expert_group
=
config
.
n_group
,
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
config
.
scoring_func
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
)
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
config
.
n_shared_experts
)
self
.
shared_experts
=
Dots1MLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
reduce_results
=
False
,
prefix
=
f
"
{
prefix
}
.shared_experts"
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
if
self
.
n_shared_experts
is
not
None
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
if
shared_output
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
if
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
class
Dots1Attention
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
config
:
PretrainedConfig
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_kv_heads
=
num_kv_heads
if
self
.
total_num_kv_heads
>=
tp_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_kv_heads
%
tp_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
self
.
head_dim
=
getattr
(
config
,
"head_dim"
,
hidden_size
//
self
.
total_num_heads
)
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
attention_bias
=
config
.
attention_bias
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
attention_bias
,
quant_config
=
quant_config
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
)
self
.
q_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
config
.
rms_norm_eps
)
self
.
k_norm
=
RMSNorm
(
self
.
head_dim
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
=
self
.
q_norm
(
q
.
reshape
(
-
1
,
self
.
num_heads
,
self
.
head_dim
)).
reshape
(
q
.
shape
)
k
=
self
.
k_norm
(
k
.
reshape
(
-
1
,
self
.
num_kv_heads
,
self
.
head_dim
)).
reshape
(
k
.
shape
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
Dots1DecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
prefix
:
str
,
model_config
:
ModelConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
8192
)
layer_idx
=
int
(
prefix
.
split
(
sep
=
'.'
)[
-
1
])
self
.
layer_idx
=
layer_idx
self
.
self_attn
=
Dots1Attention
(
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_kv_heads
=
config
.
num_key_value_heads
,
config
=
config
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
if
(
config
.
n_routed_experts
is
not
None
and
layer_idx
>=
config
.
first_k_dense_replace
and
layer_idx
%
config
.
moe_layer_freq
==
0
):
self
.
mlp
=
Dots1MoE
(
config
=
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
else
:
self
.
mlp
=
Dots1MLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
routed_scaling_factor
=
config
.
routed_scaling_factor
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
)
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
return
hidden_states
,
residual
class
Dots1Model
(
nn
.
Module
):
fall_back_to_pt_during_load
=
False
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
model_config
=
vllm_config
.
model_config
cache_config
=
vllm_config
.
cache_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
vocab_size
=
config
.
vocab_size
if
get_pp_group
().
is_first_rank
:
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.embed_tokens"
)
else
:
self
.
embed_tokens
=
PPMissingLayer
()
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
Dots1DecoderLayer
(
config
,
prefix
,
model_config
=
model_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
),
prefix
=
f
"
{
prefix
}
.layers"
)
if
get_pp_group
().
is_last_rank
:
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
norm
=
PPMissingLayer
()
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
layer
in
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]:
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
@
support_torch_compile
class
Dots1ForCausalLM
(
nn
.
Module
,
SupportsPP
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
Dots1Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
if
get_pp_group
().
is_last_rank
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
,
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
"residual"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
config
.
n_routed_experts
)
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
if
((
"mlp.experts."
in
name
)
and
name
not
in
params_dict
):
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
for
mapping
in
expert_params_mapping
:
param_name
,
weight_name
,
expert_id
,
shard_id
=
mapping
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
name
,
shard_id
=
shard_id
,
expert_id
=
expert_id
)
break
else
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/registry.py
View file @
65b1cbb1
...
@@ -52,6 +52,7 @@ _TEXT_GENERATION_MODELS = {
...
@@ -52,6 +52,7 @@ _TEXT_GENERATION_MODELS = {
"DeepseekForCausalLM"
:
(
"deepseek"
,
"DeepseekForCausalLM"
),
"DeepseekForCausalLM"
:
(
"deepseek"
,
"DeepseekForCausalLM"
),
"DeepseekV2ForCausalLM"
:
(
"deepseek_v2"
,
"DeepseekV2ForCausalLM"
),
"DeepseekV2ForCausalLM"
:
(
"deepseek_v2"
,
"DeepseekV2ForCausalLM"
),
"DeepseekV3ForCausalLM"
:
(
"deepseek_v2"
,
"DeepseekV3ForCausalLM"
),
"DeepseekV3ForCausalLM"
:
(
"deepseek_v2"
,
"DeepseekV3ForCausalLM"
),
"Dots1ForCausalLM"
:
(
"dots1"
,
"Dots1ForCausalLM"
),
"ExaoneForCausalLM"
:
(
"exaone"
,
"ExaoneForCausalLM"
),
"ExaoneForCausalLM"
:
(
"exaone"
,
"ExaoneForCausalLM"
),
"FalconForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"FalconForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"Fairseq2LlamaForCausalLM"
:
(
"fairseq2_llama"
,
"Fairseq2LlamaForCausalLM"
),
"Fairseq2LlamaForCausalLM"
:
(
"fairseq2_llama"
,
"Fairseq2LlamaForCausalLM"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment