Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
449d1bce
Unverified
Commit
449d1bce
authored
Feb 06, 2025
by
Michael Goin
Committed by
GitHub
Feb 05, 2025
Browse files
[Misc] Remove duplicated DeepSeek V2/V3 model definition (#12793)
parent
1a6fcad4
Changes
4
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
36 additions
and
821 deletions
+36
-821
vllm/config.py
vllm/config.py
+0
-1
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+35
-13
vllm/model_executor/models/deepseek_v3.py
vllm/model_executor/models/deepseek_v3.py
+0
-806
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-1
No files found.
vllm/config.py
View file @
449d1bce
...
@@ -754,7 +754,6 @@ class ModelConfig:
...
@@ -754,7 +754,6 @@ class ModelConfig:
@
property
@
property
def
is_deepseek_mla
(
self
)
->
bool
:
def
is_deepseek_mla
(
self
)
->
bool
:
# TODO add deepseek_v3
return
(
hasattr
(
self
.
hf_text_config
,
"model_type"
))
\
return
(
hasattr
(
self
.
hf_text_config
,
"model_type"
))
\
and
(
self
.
hf_text_config
.
model_type
in
\
and
(
self
.
hf_text_config
.
model_type
in
\
(
'deepseek_v2'
,
'deepseek_v3'
))
\
(
'deepseek_v2'
,
'deepseek_v3'
))
\
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
449d1bce
...
@@ -21,7 +21,7 @@
...
@@ -21,7 +21,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only DeepseekV2 model."""
"""Inference-only DeepseekV2
/DeepseekV3
model."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
...
@@ -115,7 +115,19 @@ class DeepseekV2MoE(nn.Module):
...
@@ -115,7 +115,19 @@ class DeepseekV2MoE(nn.Module):
raise
ValueError
(
f
"Unsupported activation:
{
config
.
hidden_act
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
config
.
hidden_act
}
. "
"Only silu is supported for now."
)
"Only silu is supported for now."
)
self
.
experts
=
FusedMoE
(
num_experts
=
config
.
n_routed_experts
,
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
n_routed_experts
,
bias
=
False
,
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.gate"
)
if
config
.
topk_method
==
"noaux_tc"
:
self
.
gate
.
e_score_correction_bias
=
nn
.
Parameter
(
torch
.
empty
(
config
.
n_routed_experts
))
else
:
self
.
gate
.
e_score_correction_bias
=
None
self
.
experts
=
FusedMoE
(
num_experts
=
config
.
n_routed_experts
,
top_k
=
config
.
num_experts_per_tok
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
...
@@ -125,13 +137,10 @@ class DeepseekV2MoE(nn.Module):
...
@@ -125,13 +137,10 @@ class DeepseekV2MoE(nn.Module):
use_grouped_topk
=
True
,
use_grouped_topk
=
True
,
num_expert_group
=
config
.
n_group
,
num_expert_group
=
config
.
n_group
,
topk_group
=
config
.
topk_group
,
topk_group
=
config
.
topk_group
,
prefix
=
f
"
{
prefix
}
.experts"
)
prefix
=
f
"
{
prefix
}
.experts"
,
scoring_func
=
config
.
scoring_func
,
e_score_correction_bias
=
self
.
gate
.
e_score_correction_bias
)
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
n_routed_experts
,
bias
=
False
,
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.gate"
)
if
config
.
n_shared_experts
is
not
None
:
if
config
.
n_shared_experts
is
not
None
:
intermediate_size
=
(
config
.
moe_intermediate_size
*
intermediate_size
=
(
config
.
moe_intermediate_size
*
config
.
n_shared_experts
)
config
.
n_shared_experts
)
...
@@ -732,6 +741,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
...
@@ -732,6 +741,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
# TODO(simon): support nextn predict layers
if
hasattr
(
self
.
config
,
"num_nextn_predict_layers"
)
and
self
.
config
.
num_nextn_predict_layers
>
0
:
assert
self
.
config
.
num_nextn_predict_layers
==
1
layer_idx
=
self
.
config
.
num_hidden_layers
if
name
.
startswith
(
f
"model.layers.
{
layer_idx
}
"
):
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
# Skip non-stacked layers and experts (experts handled below).
# Skip non-stacked layers and experts (experts handled below).
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
...
@@ -793,3 +811,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
...
@@ -793,3 +811,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
loaded_params
.
add
(
name
)
return
loaded_params
return
loaded_params
class
DeepseekV3ForCausalLM
(
DeepseekV2ForCausalLM
):
pass
vllm/model_executor/models/deepseek_v3.py
deleted
100644 → 0
View file @
1a6fcad4
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/registry.py
View file @
449d1bce
...
@@ -45,7 +45,7 @@ _TEXT_GENERATION_MODELS = {
...
@@ -45,7 +45,7 @@ _TEXT_GENERATION_MODELS = {
"DeciLMForCausalLM"
:
(
"decilm"
,
"DeciLMForCausalLM"
),
"DeciLMForCausalLM"
:
(
"decilm"
,
"DeciLMForCausalLM"
),
"DeepseekForCausalLM"
:
(
"deepseek"
,
"DeepseekForCausalLM"
),
"DeepseekForCausalLM"
:
(
"deepseek"
,
"DeepseekForCausalLM"
),
"DeepseekV2ForCausalLM"
:
(
"deepseek_v2"
,
"DeepseekV2ForCausalLM"
),
"DeepseekV2ForCausalLM"
:
(
"deepseek_v2"
,
"DeepseekV2ForCausalLM"
),
"DeepseekV3ForCausalLM"
:
(
"deepseek_v
3
"
,
"DeepseekV3ForCausalLM"
),
"DeepseekV3ForCausalLM"
:
(
"deepseek_v
2
"
,
"DeepseekV3ForCausalLM"
),
"ExaoneForCausalLM"
:
(
"exaone"
,
"ExaoneForCausalLM"
),
"ExaoneForCausalLM"
:
(
"exaone"
,
"ExaoneForCausalLM"
),
"FalconForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"FalconForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"Fairseq2LlamaForCausalLM"
:
(
"fairseq2_llama"
,
"Fairseq2LlamaForCausalLM"
),
"Fairseq2LlamaForCausalLM"
:
(
"fairseq2_llama"
,
"Fairseq2LlamaForCausalLM"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment