Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
de889cb6
Commit
de889cb6
authored
Feb 05, 2026
by
zhuwenwen
Browse files
sync v0.15.1
parent
c721b814
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
151 additions
and
93 deletions
+151
-93
benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
+3
-3
benchmarks/kernels/benchmark_moe_permute_unpermute.py
benchmarks/kernels/benchmark_moe_permute_unpermute.py
+106
-39
tests/kernels/moe/modular_kernel_tools/common.py
tests/kernels/moe/modular_kernel_tools/common.py
+0
-1
tests/kernels/moe/modular_kernel_tools/mk_objects.py
tests/kernels/moe/modular_kernel_tools/mk_objects.py
+0
-1
tests/models/language/pooling/test_token_classification.py
tests/models/language/pooling/test_token_classification.py
+1
-0
tests/models/multimodal/generation/test_common.py
tests/models/multimodal/generation/test_common.py
+16
-17
tests/models/registry.py
tests/models/registry.py
+2
-2
tests/models/test_initialization.py
tests/models/test_initialization.py
+0
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+1
-1
vllm/distributed/device_communicators/cpu_communicator.py
vllm/distributed/device_communicators/cpu_communicator.py
+3
-4
vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/cuda_communicator.py
+2
-3
vllm/distributed/device_communicators/mnnvl_compat.py
vllm/distributed/device_communicators/mnnvl_compat.py
+3
-1
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+3
-3
vllm/entrypoints/openai/completion/serving.py
vllm/entrypoints/openai/completion/serving.py
+1
-2
vllm/envs.py
vllm/envs.py
+1
-1
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+1
-3
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+1
-1
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+2
-6
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+3
-2
vllm/model_executor/models/falcon_h1.py
vllm/model_executor/models/falcon_h1.py
+2
-2
No files found.
benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
View file @
de889cb6
...
...
@@ -197,7 +197,7 @@ def bench_run(
)
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
MoEPrepareAndFinalizeNoEP
(
defer_input_quant
=
True
),
CutlassExpertsFp4
(
make_dummy_moe_config
(),
quant_config
=
quant_config
,
...
...
@@ -242,7 +242,7 @@ def bench_run(
)
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
MoEPrepareAndFinalizeNoEP
(
defer_input_quant
=
True
),
CutlassExpertsFp4
(
make_dummy_moe_config
(),
quant_config
=
quant_config
,
...
...
@@ -520,4 +520,4 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--limit-per-out-ch"
,
nargs
=
"+"
,
type
=
int
,
default
=
[])
args
=
parser
.
parse_args
()
main
(
args
)
main
(
args
)
\ No newline at end of file
benchmarks/kernels/benchmark_moe_permute_unpermute.py
View file @
de889cb6
...
...
@@ -10,6 +10,8 @@ from transformers import AutoConfig
from
vllm.model_executor.layers.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
(
_moe_permute
,
_moe_unpermute_and_reduce
,
moe_permute
,
moe_unpermute
,
)
...
...
@@ -39,6 +41,7 @@ def benchmark_permute(
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
num_iters
:
int
=
100
,
use_customized_permute
:
bool
=
False
,
)
->
float
:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
...
...
@@ -61,14 +64,29 @@ def benchmark_permute(
input_gating
.
copy_
(
gating_output
[
i
])
def
run
():
moe_permute
(
qhidden_states
,
a1q_scale
=
None
,
topk_ids
=
topk_ids
,
n_expert
=
num_experts
,
expert_map
=
None
,
align_block_size
=
align_block_size
,
)
if
use_customized_permute
:
(
permuted_hidden_states
,
a1q_scale
,
first_token_off
,
inv_perm_idx
,
m_indices
,
)
=
moe_permute
(
qhidden_states
,
a1q_scale
=
None
,
topk_ids
=
topk_ids
,
n_expert
=
num_experts
,
expert_map
=
None
,
align_block_size
=
align_block_size
,
)
else
:
(
permuted_hidden_states
,
a1q_scale
,
sorted_token_ids
,
expert_ids
,
inv_perm
,
)
=
_moe_permute
(
qhidden_states
,
None
,
topk_ids
,
num_experts
,
None
,
16
)
# JIT compilation & warmup
run
()
...
...
@@ -113,9 +131,11 @@ def benchmark_unpermute(
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
num_iters
:
int
=
100
,
use_customized_permute
:
bool
=
False
,
)
->
float
:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
output_hidden_states
=
torch
.
empty_like
(
hidden_states
)
if
use_fp8_w8a8
:
align_block_size
=
128
# deepgemm needs 128 m aligned block
qhidden_states
,
scale
=
_fp8_quantize
(
hidden_states
,
None
,
None
)
...
...
@@ -130,37 +150,78 @@ def benchmark_unpermute(
)
def
prepare
():
(
permuted_hidden_states
,
_
,
first_token_off
,
inv_perm_idx
,
_
,
)
=
moe_permute
(
qhidden_states
,
a1q_scale
=
None
,
topk_ids
=
topk_ids
,
n_expert
=
num_experts
,
expert_map
=
None
,
align_block_size
=
align_block_size
,
)
# convert to fp16/bf16 as gemm output
return
(
permuted_hidden_states
.
to
(
dtype
),
first_token_off
,
inv_perm_idx
,
)
if
use_customized_permute
:
(
permuted_hidden_states
,
a1q_scale
,
first_token_off
,
inv_perm_idx
,
m_indices
,
)
=
moe_permute
(
qhidden_states
,
a1q_scale
=
None
,
topk_ids
=
topk_ids
,
n_expert
=
num_experts
,
expert_map
=
None
,
align_block_size
=
align_block_size
,
)
# convert to fp16/bf16 as gemm output
return
(
permuted_hidden_states
.
to
(
dtype
),
first_token_off
,
inv_perm_idx
,
m_indices
,
)
else
:
(
permuted_qhidden_states
,
a1q_scale
,
sorted_token_ids
,
expert_ids
,
inv_perm
,
)
=
_moe_permute
(
qhidden_states
,
None
,
topk_ids
,
num_experts
,
None
,
block_m
=
16
)
# convert to fp16/bf16 as gemm output
return
(
permuted_qhidden_states
.
to
(
dtype
),
a1q_scale
,
sorted_token_ids
,
expert_ids
,
inv_perm
,
)
def
run
(
input
:
tuple
):
(
permuted_hidden_states
,
first_token_off
,
inv_perm_idx
)
=
input
output
=
torch
.
empty_like
(
hidden_states
)
moe_unpermute
(
output
,
permuted_hidden_states
,
topk_weights
,
inv_perm_idx
,
first_token_off
,
)
if
use_customized_permute
:
(
permuted_hidden_states
,
first_token_off
,
inv_perm_idx
,
m_indices
,
)
=
input
output
=
torch
.
empty_like
(
hidden_states
)
moe_unpermute
(
output
,
permuted_hidden_states
,
topk_weights
,
inv_perm_idx
,
first_token_off
,
)
else
:
(
permuted_hidden_states
,
a1q_scale
,
sorted_token_ids
,
expert_ids
,
inv_perm
,
)
=
input
_moe_unpermute_and_reduce
(
output_hidden_states
,
permuted_hidden_states
,
inv_perm
,
topk_weights
,
True
,
)
# JIT compilation & warmup
input
=
prepare
()
...
...
@@ -215,7 +276,8 @@ class BenchmarkWorker:
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
)
->
tuple
[
float
,
float
]:
use_customized_permute
:
bool
=
False
,
)
->
tuple
[
dict
[
str
,
int
],
float
]:
set_random_seed
(
self
.
seed
)
permute_time
=
benchmark_permute
(
...
...
@@ -227,6 +289,7 @@ class BenchmarkWorker:
use_fp8_w8a8
,
use_int8_w8a16
,
num_iters
=
100
,
use_customized_permute
=
use_customized_permute
,
)
unpermute_time
=
benchmark_unpermute
(
num_tokens
,
...
...
@@ -237,6 +300,7 @@ class BenchmarkWorker:
use_fp8_w8a8
,
use_int8_w8a16
,
num_iters
=
100
,
use_customized_permute
=
use_customized_permute
,
)
return
permute_time
,
unpermute_time
...
...
@@ -283,6 +347,7 @@ def main(args: argparse.Namespace):
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
dtype
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
use_customized_permute
=
args
.
use_customized_permute
if
args
.
batch_size
is
None
:
batch_sizes
=
[
...
...
@@ -334,6 +399,7 @@ def main(args: argparse.Namespace):
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
use_customized_permute
,
)
for
batch_size
in
batch_sizes
],
...
...
@@ -353,9 +419,10 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"auto"
,
"fp8_w8a8"
,
"int8_w8a16"
],
default
=
"auto"
)
parser
.
add_argument
(
"--use-customized-permute"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--trust-remote-code"
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
main
(
args
)
main
(
args
)
\ No newline at end of file
tests/kernels/moe/modular_kernel_tools/common.py
View file @
de889cb6
...
...
@@ -607,7 +607,6 @@ def make_modular_kernel(
prepare_finalize
=
make_prepare_finalize
(
config
.
prepare_finalize_type
,
config
.
all2all_backend
(),
moe
,
quant_config
)
assert
prepare_finalize
is
not
None
fused_experts
=
make_fused_experts
(
config
.
fused_experts_type
,
...
...
tests/kernels/moe/modular_kernel_tools/mk_objects.py
View file @
de889cb6
...
...
@@ -445,7 +445,6 @@ def make_prepare_finalize(
)
else
:
return
MoEPrepareAndFinalizeNoEP
()
def
_slice
(
rank
:
int
,
num_local_experts
:
int
,
t
:
torch
.
Tensor
)
->
torch
.
Tensor
:
s
=
rank
*
num_local_experts
...
...
tests/models/language/pooling/test_token_classification.py
View file @
de889cb6
...
...
@@ -20,6 +20,7 @@ def test_bert_models(
model
:
str
,
dtype
:
str
,
)
->
None
:
with
vllm_runner
(
model
,
max_model_len
=
None
,
dtype
=
dtype
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
token_classify
(
example_prompts
)
...
...
tests/models/multimodal/generation/test_common.py
View file @
de889cb6
...
...
@@ -573,6 +573,21 @@ VLM_TEST_SETTINGS = {
vllm_output_post_proc
=
model_utils
.
kimiv_vl_vllm_to_hf_output
,
marks
=
[
large_gpu_mark
(
min_gb
=
48
)],
),
"llama4"
:
VLMTestInfo
(
models
=
[
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
],
prompt_formatter
=
lambda
img_prompt
:
f
"<|begin_of_text|><|header_start|>user<|header_end|>
\n\n
{
img_prompt
}
<|eot|><|header_start|>assistant<|header_end|>
\n\n
"
,
# noqa: E501
img_idx_to_prompt
=
lambda
_
:
"<|image|>"
,
test_type
=
(
VLMTestType
.
IMAGE
,
VLMTestType
.
MULTI_IMAGE
),
distributed_executor_backend
=
"mp"
,
image_size_factors
=
[(
0.25
,
0.5
,
1.0
)],
hf_model_kwargs
=
{
"device_map"
:
"auto"
},
max_model_len
=
8192
,
max_num_seqs
=
4
,
dtype
=
"bfloat16"
,
auto_cls
=
AutoModelForImageTextToText
,
tensor_parallel_size
=
4
,
marks
=
multi_gpu_marks
(
num_gpus
=
4
),
),
"llava_next"
:
VLMTestInfo
(
models
=
[
"llava-hf/llava-v1.6-mistral-7b-hf"
],
test_type
=
(
VLMTestType
.
IMAGE
,
VLMTestType
.
CUSTOM_INPUTS
),
...
...
@@ -954,22 +969,6 @@ VLM_TEST_SETTINGS = {
)
],
),
"llama4"
:
VLMTestInfo
(
models
=
[
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
],
prompt_formatter
=
lambda
img_prompt
:
f
"<|begin_of_text|><|header_start|>user<|header_end|>
\n\n
{
img_prompt
}
<|eot|><|header_start|>assistant<|header_end|>
\n\n
"
,
# noqa: E501
img_idx_to_prompt
=
lambda
_
:
"<|image|>"
,
test_type
=
(
VLMTestType
.
IMAGE
,
VLMTestType
.
MULTI_IMAGE
),
distributed_executor_backend
=
"mp"
,
image_size_factors
=
[(.
25
,
0.5
,
1.0
)],
hf_model_kwargs
=
{
"device_map"
:
"auto"
},
max_model_len
=
8192
,
max_num_seqs
=
4
,
dtype
=
"bfloat16"
,
auto_cls
=
AutoModelForImageTextToText
,
tensor_parallel_size
=
8
,
vllm_runner_kwargs
=
{
"gpu_memory_utilization"
:
0.8
},
marks
=
[
large_gpu_mark
(
min_gb
=
80
),
multi_gpu_marks
(
num_gpus
=
8
)],
),
}
...
...
@@ -1322,4 +1321,4 @@ def test_custom_inputs_models_heavy(
test_case
=
test_case
,
hf_runner
=
hf_runner
,
vllm_runner
=
vllm_runner
,
)
)
\ No newline at end of file
tests/models/registry.py
View file @
de889cb6
...
...
@@ -1061,7 +1061,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
"Glm4MoeLiteMTPModel"
:
_HfExamplesInfo
(
"zai-org/GLM-4.7-Flash"
,
speculative_model
=
"zai-org/GLM-4.7-Flash"
,
min_transformers_version
=
"5.0.0"
,
is_available_online
=
False
,
),
"LongCatFlashMTPModel"
:
_HfExamplesInfo
(
"meituan-longcat/LongCat-Flash-Chat"
,
...
...
@@ -1165,4 +1165,4 @@ class HfExampleModels:
HF_EXAMPLE_MODELS
=
HfExampleModels
(
_EXAMPLE_MODELS
)
AUTO_EXAMPLE_MODELS
=
HfExampleModels
(
_AUTOMATIC_CONVERTED_MODELS
)
AUTO_EXAMPLE_MODELS
=
HfExampleModels
(
_AUTOMATIC_CONVERTED_MODELS
)
\ No newline at end of file
tests/models/test_initialization.py
View file @
de889cb6
...
...
@@ -88,7 +88,6 @@ def can_initialize(
[
10
*
GiB_bytes
],
)
scheduler_kv_cache_config
=
generate_scheduler_kv_cache_config
(
kv_cache_configs
)
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
return
1
,
0
,
scheduler_kv_cache_config
...
...
vllm/_custom_ops.py
View file @
de889cb6
...
...
@@ -2866,7 +2866,7 @@ def onednn_mm(
)
->
torch
.
Tensor
:
output
=
torch
.
empty
((
*
x
.
shape
[
0
:
-
1
],
dnnl_handler
.
n
),
dtype
=
x
.
dtype
)
torch
.
ops
.
_C
.
onednn_mm
(
output
,
x
.
reshape
(
-
1
,
dnnl_handler
.
k
),
bias
,
dnnl_handler
.
handler
_tensor
output
,
x
.
reshape
(
-
1
,
dnnl_handler
.
k
),
bias
,
dnnl_handler
.
handler
)
return
output
...
...
vllm/distributed/device_communicators/cpu_communicator.py
View file @
de889cb6
...
...
@@ -130,20 +130,19 @@ class CpuCommunicator(DeviceCommunicatorBase):
)
->
dict
[
str
,
torch
.
Tensor
|
Any
]:
return
self
.
dist_module
.
recv_tensor_dict
(
src
)
def
dispatch
(
def
dispatch
(
# type: ignore[override]
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
router_logits
,
is_sequence_parallel
,
extra_tensors
,
# type: ignore[call-arg]
extra_tensors
,
# type: ignore[call-arg]
)
def
combine
(
...
...
@@ -251,4 +250,4 @@ class _CPUSHMDistributed:
tensor_dict
:
dict
[
str
,
torch
.
Tensor
]
=
{}
for
key
,
size
,
t
in
zip
(
key_list
,
size_list
,
value_list
):
tensor_dict
[
key
]
=
t
.
view
(
size
)
return
tensor_dict
return
tensor_dict
\ No newline at end of file
vllm/distributed/device_communicators/cuda_communicator.py
View file @
de889cb6
...
...
@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
return
output_list
def
dispatch
(
def
dispatch
(
# type: ignore[override]
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
...
...
@@ -332,7 +332,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
assert
self
.
all2all_manager
is
not
None
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
...
...
@@ -348,4 +347,4 @@ class CudaCommunicator(DeviceCommunicatorBase):
hidden_states
=
self
.
all2all_manager
.
combine
(
hidden_states
,
is_sequence_parallel
)
return
hidden_states
return
hidden_states
\ No newline at end of file
vllm/distributed/device_communicators/mnnvl_compat.py
View file @
de889cb6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch.distributed
as
dist
from
flashinfer.comm.mnnvl
import
CommBackend
as
CommBackend
...
...
@@ -23,3 +22,6 @@ class CustomCommunicator(CommBackend):
gathered
=
[
None
]
*
self
.
Get_size
()
dist
.
all_gather_object
(
gathered
,
data
,
group
=
self
.
_group
)
return
gathered
def
Split
(
self
,
color
:
int
,
key
:
int
)
->
"CustomCommunicator"
:
return
self
\ No newline at end of file
vllm/entrypoints/openai/api_server.py
View file @
de889cb6
...
...
@@ -930,8 +930,8 @@ async def run_server_worker(
if
args
.
reasoning_parser_plugin
and
len
(
args
.
reasoning_parser_plugin
)
>
3
:
ReasoningParserManager
.
import_reasoning_parser
(
args
.
reasoning_parser_plugin
)
#
Get uvicorn log config (from file or with endpoint filter)
log_config
=
get_uvicorn
_log_config
(
args
)
#
Load logging config for uvicorn if specified
log_config
=
load
_log_config
(
args
.
log_config_file
)
if
log_config
is
not
None
:
uvicorn_kwargs
[
"log_config"
]
=
log_config
...
...
@@ -988,4 +988,4 @@ if __name__ == "__main__":
args
=
parser
.
parse_args
()
validate_parsed_serve_args
(
args
)
uvloop
.
run
(
run_server
(
args
))
uvloop
.
run
(
run_server
(
args
))
\ No newline at end of file
vllm/entrypoints/openai/completion/serving.py
View file @
de889cb6
...
...
@@ -36,7 +36,6 @@ from vllm.entrypoints.renderer import RenderConfig
from
vllm.entrypoints.utils
import
get_max_tokens
,
should_include_usage
from
vllm.exceptions
import
VLLMValidationError
from
vllm.inputs.data
import
EmbedsPrompt
,
TokensPrompt
,
is_embeds_prompt
from
vllm.inputs.parse
import
get_prompt_components
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
from
vllm.outputs
import
RequestOutput
...
...
@@ -744,4 +743,4 @@ class OpenAIServingCompletion(OpenAIServing):
add_special_tokens
=
request
.
add_special_tokens
,
cache_salt
=
request
.
cache_salt
,
needs_detokenization
=
bool
(
request
.
echo
and
not
request
.
return_token_ids
),
)
)
\ No newline at end of file
vllm/envs.py
View file @
de889cb6
...
...
@@ -881,7 +881,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_DIR"
:
lambda
:
os
.
getenv
(
"VLLM_TORCH_PROFILER_DIR"
),
# Enable torch profiler to record shapes if set to 1.
# Deprecated, see profiler_config.
# Deprecated, see profiler_config.
"VLLM_TORCH_PROFILER_RECORD_SHAPES"
:
lambda
:
(
os
.
getenv
(
"VLLM_TORCH_PROFILER_RECORD_SHAPES"
)
),
...
...
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
View file @
de889cb6
...
...
@@ -176,9 +176,7 @@ def _fused_moe_lora_kernel(
# GDC wait waits for ALL programs in the prior kernel to complete
# before continuing.
# pre-fetch lora weight
# add (offs_bn < N) mask; optional .ca for B
b
=
tl
.
load
(
b_ptrs
,
mask
=
offs_k
[:,
None
]
<
k_remaining
,
other
=
0.0
)
if
USE_GDC
and
not
IS_PRIMARY
:
tl
.
extra
.
cuda
.
gdc_wait
()
a
=
tl
.
load
(
...
...
@@ -683,4 +681,4 @@ try:
except
AttributeError
:
fused_moe_lora
=
_fused_moe_lora
fused_moe_lora_shrink
=
_fused_moe_lora_shrink
fused_moe_lora_expand
=
_fused_moe_lora_expand
fused_moe_lora_expand
=
_fused_moe_lora_expand
\ No newline at end of file
vllm/model_executor/models/commandr.py
View file @
de889cb6
...
...
@@ -438,7 +438,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
@
torch
.
no_grad
()
def
forward
(
self
,
input_ids
:
torch
.
Tenso
,
input_ids
:
torch
.
Tenso
r
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
vllm/model_executor/models/deepseek_mtp.py
View file @
de889cb6
...
...
@@ -316,11 +316,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
# Determine split axis based on op type
# gate/up: ColumnParallel → split along dim 0
# down: RowParallel → split along dim 1
split_dim
=
(
1
if
(
"down_proj.weight"
in
name
and
loaded_weight
.
ndim
>
1
)
else
0
)
split_dim
=
1
if
"down_proj.weight"
in
name
else
0
total
=
loaded_weight
.
shape
[
split_dim
]
assert
total
%
num_chunks
==
0
,
(
f
"Shared expert weight dim
{
total
}
"
...
...
@@ -448,4 +444,4 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
elif
shared_weight
:
# treat shared weights as top level weights
name
=
name
.
replace
(
f
"model.layers.
{
spec_layer
}
."
,
"model."
)
return
name
return
name
\ No newline at end of file
vllm/model_executor/models/deepseek_v2.py
View file @
de889cb6
...
...
@@ -1375,7 +1375,7 @@ class DeepseekV2ForCausalLM(
break
else
:
is_expert_weight
=
False
# Special handling: when AITER fusion_shared_experts is enabled,
# checkpoints may provide a single widened shared_experts tensor
# without explicit expert indices
...
...
@@ -1487,6 +1487,7 @@ class DeepseekV2ForCausalLM(
weight_loader
(
param
,
loaded_weight
)
if
not
is_fusion_moe_shared_experts_layer
:
loaded_params
.
add
(
name
)
return
loaded_params
...
...
@@ -1511,4 +1512,4 @@ def get_spec_layer_idx_from_weight_name(
for
i
in
range
(
config
.
num_nextn_predict_layers
):
if
weight_name
.
startswith
(
f
"model.layers.
{
layer_idx
+
i
}
."
):
return
layer_idx
+
i
return
None
return
None
\ No newline at end of file
vllm/model_executor/models/falcon_h1.py
View file @
de889cb6
...
...
@@ -459,7 +459,7 @@ class FalconH1Model(nn.Module):
def
forward
(
self
,
input_ids
:
torch
.
Tensor
|
None
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
...
...
@@ -678,4 +678,4 @@ class FalconH1ForCausalLM(
if
self
.
tie_word_embeddings
:
loaded_params
.
add
(
"lm_head.weight"
)
return
loaded_params
return
loaded_params
\ No newline at end of file
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment