Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
05ff5be1
Commit
05ff5be1
authored
Sep 14, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.2' into v0.10.2-ori
parents
38d80967
26b999c7
Changes
17
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
1306 additions
and
68 deletions
+1306
-68
csrc/attention/mla/sm100_cutlass_mla_kernel.cu
csrc/attention/mla/sm100_cutlass_mla_kernel.cu
+1
-0
tests/kernels/moe/test_mxfp4_moe.py
tests/kernels/moe/test_mxfp4_moe.py
+2
-2
tests/models/language/pooling/test_mm_classifier_conversion.py
.../models/language/pooling/test_mm_classifier_conversion.py
+114
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+90
-52
vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_H100_80GB_HBM3.json
...onfigs/E=256,N=512,device_name=NVIDIA_H100_80GB_HBM3.json
+146
-0
vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H20-3e.json
...ed_moe/configs/E=512,N=128,device_name=NVIDIA_H20-3e.json
+146
-0
vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3.json
...onfigs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3.json
+146
-0
vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H20-3e.json
...ed_moe/configs/E=512,N=256,device_name=NVIDIA_H20-3e.json
+146
-0
vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H100_80GB_HBM3.json
...onfigs/E=512,N=512,device_name=NVIDIA_H100_80GB_HBM3.json
+146
-0
vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H20-3e.json
...ed_moe/configs/E=512,N=512,device_name=NVIDIA_H20-3e.json
+146
-0
vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H20-3e.json
...sed_moe/configs/E=512,N=64,device_name=NVIDIA_H20-3e.json
+146
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+14
-4
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+43
-2
vllm/model_executor/models/gemma3_mm.py
vllm/model_executor/models/gemma3_mm.py
+5
-1
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+6
-3
vllm/model_executor/models/qwen3_next.py
vllm/model_executor/models/qwen3_next.py
+7
-3
vllm/v1/attention/backends/gdn_attn.py
vllm/v1/attention/backends/gdn_attn.py
+2
-1
No files found.
csrc/attention/mla/sm100_cutlass_mla_kernel.cu
View file @
05ff5be1
...
...
@@ -43,6 +43,7 @@ void sm100_cutlass_mla_decode(
torch
::
Tensor
const
&
seq_lens
,
torch
::
Tensor
const
&
page_table
,
torch
::
Tensor
const
&
workspace
,
double
sm_scale
,
int64_t
num_kv_splits
)
{
TORCH_CHECK
(
false
,
"CUDA version must be >= 12.4 for cutlass_mla_decode"
);
}
...
...
tests/kernels/moe/test_mxfp4_moe.py
View file @
05ff5be1
...
...
@@ -771,11 +771,11 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
w13_ref
=
dequant_mxfp4_batches
(
w13_q
.
view
(
torch
.
uint8
),
w13_scale
.
view
(
torch
.
uint8
).
reshape
(
-
1
)).
to
(
torch
.
float32
).
reshape
(
num_experts
,
2
*
intermediate_size
,
hidden_size
)
num_experts
,
2
*
intermediate_size
,
hidden_size
)
.
to
(
device
)
w2_ref
=
dequant_mxfp4_batches
(
w2_q
.
view
(
torch
.
uint8
),
w2_scale
.
view
(
torch
.
uint8
).
reshape
(
-
1
)).
to
(
torch
.
float32
).
reshape
(
num_experts
,
hidden_size
,
intermediate_size
)
num_experts
,
hidden_size
,
intermediate_size
)
.
to
(
device
)
# Quantize activations for SM100 path and dequantize for reference
hidden_states_q
,
hidden_states_sf
=
mxfp8_quantize
(
hidden_states
,
True
,
32
)
...
...
tests/models/language/pooling/test_mm_classifier_conversion.py
0 → 100644
View file @
05ff5be1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.platforms
import
current_platform
def
test_idefics_multimodal
(
vllm_runner
,
monkeypatch
,
)
->
None
:
if
current_platform
.
is_rocm
():
# ROCm Triton FA does not currently support sliding window attention
# switch to use ROCm CK FA backend
monkeypatch
.
setenv
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"False"
)
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
with
vllm_runner
(
model_name
=
"HuggingFaceM4/Idefics3-8B-Llama3"
,
runner
=
"pooling"
,
task
=
"classify"
,
convert
=
"classify"
,
load_format
=
"dummy"
,
max_model_len
=
512
,
enforce_eager
=
True
,
tensor_parallel_size
=
1
,
disable_log_stats
=
True
,
dtype
=
"bfloat16"
)
as
vllm_model
:
llm
=
vllm_model
.
get_llm
()
outputs
=
llm
.
classify
(
prompts
)
for
output
in
outputs
:
assert
len
(
output
.
outputs
.
probs
)
==
2
def
update_config
(
config
):
config
.
text_config
.
update
({
"architectures"
:
[
"Gemma3ForSequenceClassification"
],
"classifier_from_token"
:
[
"A"
,
"B"
,
"C"
,
"D"
,
"E"
],
"method"
:
"no_post_processing"
,
"id2label"
:
{
"A"
:
"Chair"
,
"B"
:
"Couch"
,
"C"
:
"Table"
,
"D"
:
"Bed"
,
"E"
:
"Cupboard"
},
})
return
config
def
test_gemma_multimodal
(
vllm_runner
,
monkeypatch
,
)
->
None
:
if
current_platform
.
is_rocm
():
# ROCm Triton FA does not currently support sliding window attention
# switch to use ROCm CK FA backend
monkeypatch
.
setenv
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"False"
)
messages
=
[{
"role"
:
"system"
,
"content"
:
"""
You are a helpful assistant. You will be given a product description
which may also include an image. Classify the following product into
one of the categories:
A = chair
B = couch
C = table
D = bed
E = cupboard
You'll answer with exactly one letter (A, B, C, D, or E)."""
},
{
"role"
:
"user"
,
"content"
:
[{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
"https://upload.wikimedia.org/wikipedia/commons/c/c6/Set_of_fourteen_side_chairs_MET_DP110780.jpg"
}
},
{
"type"
:
"text"
,
"text"
:
"A fine 19th century piece of furniture."
}]
}]
with
vllm_runner
(
model_name
=
"google/gemma-3-4b-it"
,
runner
=
"pooling"
,
task
=
"classify"
,
convert
=
"classify"
,
load_format
=
"auto"
,
hf_overrides
=
update_config
,
override_pooler_config
=
{
"pooling_type"
:
"LAST"
},
max_model_len
=
512
,
enforce_eager
=
True
,
tensor_parallel_size
=
1
,
disable_log_stats
=
True
,
dtype
=
"bfloat16"
)
as
vllm_model
:
llm
=
vllm_model
.
get_llm
()
prompts
=
llm
.
preprocess_chat
(
messages
)
result
=
llm
.
classify
(
prompts
)
assert
result
[
0
].
outputs
.
probs
[
0
]
>
0.95
assert
all
(
c
<
0.05
for
c
in
result
[
0
].
outputs
.
probs
[
1
:])
\ No newline at end of file
vllm/entrypoints/llm.py
View file @
05ff5be1
...
...
@@ -703,13 +703,10 @@ class LLM:
return
outputs
def
chat
(
def
preprocess_
chat
(
self
,
messages
:
Union
[
list
[
ChatCompletionMessageParam
],
list
[
list
[
ChatCompletionMessageParam
]]],
sampling_params
:
Optional
[
Union
[
SamplingParams
,
list
[
SamplingParams
]]]
=
None
,
use_tqdm
:
Union
[
bool
,
Callable
[...,
tqdm
]]
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
chat_template
:
Optional
[
str
]
=
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
...
...
@@ -718,56 +715,16 @@ class LLM:
tools
:
Optional
[
list
[
dict
[
str
,
Any
]]]
=
None
,
chat_template_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
list
[
RequestOutpu
t
]:
)
->
list
[
TokensPromp
t
]:
"""
Generate responses for a chat conversation.
The chat conversation is converted into a text prompt using the
tokenizer and calls the [generate][vllm.LLM.generate] method to generate
the responses.
Multi-modal inputs can be passed in the same way you would pass them
to the OpenAI API.
Args:
messages: A list of conversations or a single conversation.
- Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys.
sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it
is a single value, it is applied to every prompt. When it
is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
chat_template: The template to use for structuring the chat.
If not provided, the model's default chat template will be used.
chat_template_content_format: The format to render message content.
- "string" will render the content as a string.
Example: `"Who are you?"`
- "openai" will render the content as a list of dictionaries,
similar to OpenAI schema.
Example: `[{"type": "text", "text": "Who are you?"}]`
add_generation_prompt: If True, adds a generation template
to each message.
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be
`True` if `add_generation_prompt` is also `True`.
chat_template_kwargs: Additional kwargs to pass to the chat
template.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
Generate prompt for a chat conversation. The pre-processed
prompt can then be used as input for the other LLM methods.
Refer to `chat` for a complete description of the arguments.
Returns:
A list of `RequestOutput` objects containing the generated
responses in the same order as the input messages.
A list of `TokensPrompts` objects containing the tokenized
prompt after chat template interpolation, and the
pre-processed multi-modal inputs.
"""
list_of_messages
:
list
[
list
[
ChatCompletionMessageParam
]]
...
...
@@ -800,7 +757,7 @@ class LLM:
)
_chat_template_kwargs
.
update
(
chat_template_kwargs
or
{})
prompts
:
list
[
Union
[
TokensPrompt
,
TextPrompt
]
]
=
[]
prompts
:
list
[
TokensPrompt
]
=
[]
for
msgs
in
list_of_messages
:
# NOTE: _parse_chat_message_content_parts() currently doesn't
...
...
@@ -844,6 +801,87 @@ class LLM:
prompts
.
append
(
prompt
)
return
prompts
def
chat
(
self
,
messages
:
Union
[
list
[
ChatCompletionMessageParam
],
list
[
list
[
ChatCompletionMessageParam
]]],
sampling_params
:
Optional
[
Union
[
SamplingParams
,
list
[
SamplingParams
]]]
=
None
,
use_tqdm
:
Union
[
bool
,
Callable
[...,
tqdm
]]
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
chat_template
:
Optional
[
str
]
=
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
add_generation_prompt
:
bool
=
True
,
continue_final_message
:
bool
=
False
,
tools
:
Optional
[
list
[
dict
[
str
,
Any
]]]
=
None
,
chat_template_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
list
[
RequestOutput
]:
"""
Generate responses for a chat conversation.
The chat conversation is converted into a text prompt using the
tokenizer and calls the [generate][vllm.LLM.generate] method to generate
the responses.
Multi-modal inputs can be passed in the same way you would pass them
to the OpenAI API.
Args:
messages: A list of conversations or a single conversation.
- Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys.
sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it
is a single value, it is applied to every prompt. When it
is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
chat_template: The template to use for structuring the chat.
If not provided, the model's default chat template will be used.
chat_template_content_format: The format to render message content.
- "string" will render the content as a string.
Example: `"Who are you?"`
- "openai" will render the content as a list of dictionaries,
similar to OpenAI schema.
Example: `[{"type": "text", "text": "Who are you?"}]`
add_generation_prompt: If True, adds a generation template
to each message.
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be
`True` if `add_generation_prompt` is also `True`.
chat_template_kwargs: Additional kwargs to pass to the chat
template.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
Returns:
A list of `RequestOutput` objects containing the generated
responses in the same order as the input messages.
"""
prompts
=
self
.
preprocess_chat
(
messages
=
messages
,
lora_request
=
lora_request
,
chat_template
=
chat_template
,
chat_template_content_format
=
chat_template_content_format
,
add_generation_prompt
=
add_generation_prompt
,
continue_final_message
=
continue_final_message
,
tools
=
tools
,
chat_template_kwargs
=
chat_template_kwargs
,
mm_processor_kwargs
=
mm_processor_kwargs
,
)
return
self
.
generate
(
prompts
,
sampling_params
=
sampling_params
,
...
...
vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_H100_80GB_HBM3.json
0 → 100644
View file @
05ff5be1
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
5
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"512"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"1536"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
}
}
vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H20-3e.json
0 → 100644
View file @
05ff5be1
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"2"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"512"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"1024"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"1536"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"3072"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"4096"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
2
}
}
vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3.json
0 → 100644
View file @
05ff5be1
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"2"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"512"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"1024"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"1536"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"3072"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
8
,
"num_stages"
:
3
}
}
vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H20-3e.json
0 → 100644
View file @
05ff5be1
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"512"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"1024"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"1536"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"3072"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"4096"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
3
}
}
vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H100_80GB_HBM3.json
0 → 100644
View file @
05ff5be1
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"512"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"1024"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"1536"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"3072"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
}
}
vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H20-3e.json
0 → 100644
View file @
05ff5be1
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"512"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"1024"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"1536"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"3072"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"4096"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
3
}
}
vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H20-3e.json
0 → 100644
View file @
05ff5be1
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"2"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"512"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"1024"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"1536"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"3072"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"4096"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
2
}
}
vllm/model_executor/model_loader/utils.py
View file @
05ff5be1
...
...
@@ -19,10 +19,11 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.linear
import
QKVCrossParallelLinear
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.models.adapters
import
(
as_embedding_model
,
as_reward_model
,
as_seq_cls_model
)
from
vllm.model_executor.models.interfaces
import
SupportsQuant
from
vllm.model_executor.models.adapters
import
(
as_embedding_model
,
as_reward_model
,
as_seq_cls_model
,
try_create_mm_pooling_model_cls
)
from
vllm.model_executor.models.interfaces
import
(
SupportsQuant
,
supports_multimodal
)
from
vllm.utils
import
is_pin_memory_available
logger
=
init_logger
(
__name__
)
...
...
@@ -183,6 +184,15 @@ def get_model_architecture(
"performance may not be optimal."
,
arch
)
convert_type
=
model_config
.
convert_type
if
convert_type
!=
"none"
and
supports_multimodal
(
model_cls
):
logger
.
debug_once
(
"Detected conversion of Multi Modal model."
)
converted
=
try_create_mm_pooling_model_cls
(
model_cls
)
if
converted
is
not
None
:
logger
.
debug_once
(
"Creating wrapper class to forward pooler."
)
return
converted
,
arch
else
:
logger
.
debug_once
(
"Attempting direct conversion."
)
if
convert_type
==
"none"
:
pass
elif
convert_type
==
"embed"
:
...
...
vllm/model_executor/models/adapters.py
View file @
05ff5be1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
ast
import
inspect
from
collections.abc
import
Iterable
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
TypeVar
,
cast
import
torch
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.models.config
import
VerifyAndUpdateConfig
...
...
@@ -129,6 +132,41 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
return
model_name
+
pooling_suffix
def
try_create_mm_pooling_model_cls
(
orig_cls
:
_T
)
->
_T
:
class
CallVisitor
(
ast
.
NodeVisitor
):
def
__init__
(
self
):
self
.
calls
=
[]
def
visit_Call
(
self
,
node
):
if
isinstance
(
node
.
func
,
ast
.
Name
):
self
.
calls
.
append
(
node
.
func
.
id
)
self
.
generic_visit
(
node
)
visitor
=
CallVisitor
()
visitor
.
visit
(
ast
.
parse
(
inspect
.
getsource
(
orig_cls
)))
if
"init_vllm_registered_model"
not
in
visitor
.
calls
:
return
None
class
ModelForPooling
(
orig_cls
,
VllmModelForPooling
):
is_pooling_model
=
True
def
__init__
(
self
,
*
,
vllm_config
:
"VllmConfig"
,
prefix
:
str
=
""
,
**
kwargs
:
Any
,
)
->
None
:
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
self
.
pooler
=
self
.
get_language_model
().
pooler
return
ModelForPooling
# type: ignore
def
_create_pooling_model_cls
(
orig_cls
:
_T
)
->
_T
:
# Lazy import
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
...
...
@@ -399,6 +437,7 @@ def load_weights_using_from_2_way_softmax(
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
model_config
=
model
.
vllm_config
.
model_config
tokens
=
getattr
(
model
.
config
,
"classifier_from_token"
,
[])
tokens
=
cast
(
list
[
int
],
tokens
)
assert
len
(
tokens
)
==
2
...
...
@@ -406,9 +445,10 @@ def load_weights_using_from_2_way_softmax(
if
model
.
config
.
tie_word_embeddings
:
model
.
lm_head
=
model
.
model
.
embed_tokens
else
:
quant_config
=
model
.
vllm_config
.
quant_config
model
.
lm_head
=
ParallelLMHead
(
model
.
config
.
vocab_size
,
model
.
config
.
hidden_size
,
quant_config
=
model
.
quant_config
)
quant_config
=
quant_config
)
loader
=
AutoWeightsLoader
(
model
)
loaded_weights
=
loader
.
load_weights
(
weights
)
...
...
@@ -452,9 +492,10 @@ def load_weights_no_post_processing(model,
if
model
.
config
.
tie_word_embeddings
:
model
.
lm_head
=
model
.
model
.
embed_tokens
else
:
quant_config
=
model
.
vllm_config
.
quant_config
model
.
lm_head
=
ParallelLMHead
(
model
.
config
.
vocab_size
,
model
.
config
.
hidden_size
,
quant_config
=
model
.
quant_config
)
quant_config
=
quant_config
)
loader
=
AutoWeightsLoader
(
model
)
loaded_weights
=
loader
.
load_weights
(
weights
)
...
...
vllm/model_executor/models/gemma3_mm.py
View file @
05ff5be1
...
...
@@ -512,6 +512,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
architectures
=
[
"Gemma3ForCausalLM"
],
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
if
hasattr
(
self
.
language_model
,
"logits_processor"
):
# The logits processor can be unset if we're using
# automatic conversion to pooling model.
self
.
language_model
.
logits_processor
.
scale
*=
logit_scale
self
.
make_empty_intermediate_tensors
=
(
...
...
vllm/model_executor/models/qwen3_moe.py
View file @
05ff5be1
...
...
@@ -170,8 +170,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
return
quant_config
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape
=
hidden_states
.
shape
assert
hidden_states
.
dim
(
)
<=
2
,
"Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
is_input_1d
=
hidden_states
.
dim
()
==
1
hidden_dim
=
hidden_states
.
shape
[
-
1
]
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
...
...
@@ -180,7 +181,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
return
final_hidden_states
.
view
(
orig_shape
)
# return to 1d if input is 1d
return
final_hidden_states
.
squeeze
(
0
)
if
is_input_1d
else
\
final_hidden_states
class
Qwen3MoeAttention
(
nn
.
Module
):
...
...
vllm/model_executor/models/qwen3_next.py
View file @
05ff5be1
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next model."""
from
collections.abc
import
Iterable
from
itertools
import
islice
from
typing
import
Optional
import
torch
...
...
@@ -917,8 +918,11 @@ class Qwen3NextModel(nn.Module):
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
if
get_pp_group
().
is_last_rank
:
self
.
norm
=
Qwen3NextRMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
norm
=
PPMissingLayer
()
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
...
...
@@ -941,7 +945,7 @@ class Qwen3NextModel(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
layer
in
self
.
layers
:
for
layer
in
islice
(
self
.
layers
,
self
.
start_layer
,
self
.
end_layer
)
:
hidden_states
,
residual
=
layer
(
positions
=
positions
,
hidden_states
=
hidden_states
,
...
...
vllm/v1/attention/backends/gdn_attn.py
View file @
05ff5be1
...
...
@@ -209,7 +209,8 @@ class GDNAttentionMetadataBuilder(
# prepare tensors for cudagraph
if
(
self
.
use_full_cuda_graph
and
num_prefills
==
0
and
num_decodes
==
0
and
num_spec_decodes
<=
self
.
decode_cudagraph_max_bs
):
and
num_spec_decodes
<=
self
.
decode_cudagraph_max_bs
and
m
.
num_actual_tokens
<=
self
.
decode_cudagraph_max_bs
):
num_total_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
m
.
num_actual_tokens
)
batch_size
=
num_total_tokens
//
(
self
.
num_spec
+
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment