Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d76fc11e
Commit
d76fc11e
authored
Jan 28, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.15.0rc1' into v0.15.0rc1-dev
parents
38166ec4
58996f35
Changes
313
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
713 additions
and
263 deletions
+713
-263
vllm/entrypoints/pooling/classify/serving.py
vllm/entrypoints/pooling/classify/serving.py
+57
-89
vllm/entrypoints/pooling/embed/serving.py
vllm/entrypoints/pooling/embed/serving.py
+64
-102
vllm/entrypoints/utils.py
vllm/entrypoints/utils.py
+24
-3
vllm/envs.py
vllm/envs.py
+10
-6
vllm/logging_utils/__init__.py
vllm/logging_utils/__init__.py
+6
-0
vllm/logging_utils/access_log_filter.py
vllm/logging_utils/access_log_filter.py
+144
-0
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+23
-7
vllm/model_executor/layers/fused_moe/all2all_utils.py
vllm/model_executor/layers/fused_moe/all2all_utils.py
+54
-6
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+16
-10
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+14
-6
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+2
-1
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+24
-8
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+13
-0
vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
...cutor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
+226
-0
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
...model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+8
-16
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+6
-0
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+1
-1
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+1
-1
vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
.../model_executor/layers/fused_moe/fused_moe_method_base.py
+16
-0
vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
...del_executor/layers/fused_moe/fused_moe_modular_method.py
+4
-7
No files found.
vllm/entrypoints/pooling/classify/serving.py
View file @
d76fc11e
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
http
import
HTTPStatus
from
typing
import
cast
from
typing
import
Final
,
cast
import
jinja2
import
numpy
as
np
...
...
@@ -11,18 +11,8 @@ from fastapi import Request
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
ChatTemplateContentFormatOption
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.chat_completion.protocol
import
(
ChatCompletionRequest
,
)
from
vllm.entrypoints.openai.engine.protocol
import
(
ErrorResponse
,
UsageInfo
,
)
from
vllm.entrypoints.openai.engine.serving
import
(
ClassificationServeContext
,
OpenAIServing
,
ServeContext
,
)
from
vllm.entrypoints.openai.engine.protocol
import
ErrorResponse
,
UsageInfo
from
vllm.entrypoints.openai.engine.serving
import
OpenAIServing
,
ServeContext
from
vllm.entrypoints.openai.models.serving
import
OpenAIServingModels
from
vllm.entrypoints.pooling.classify.protocol
import
(
ClassificationChatRequest
,
...
...
@@ -39,60 +29,68 @@ from vllm.pooling_params import PoolingParams
logger
=
init_logger
(
__name__
)
class
ClassificationMixin
(
OpenAIServing
):
chat_template
:
str
|
None
chat_template_content_format
:
ChatTemplateContentFormatOption
trust_request_chat_template
:
bool
ClassificationServeContext
=
ServeContext
[
ClassificationRequest
]
class
ServingClassification
(
OpenAIServing
):
request_id_prefix
=
"classify"
def
__init__
(
self
,
engine_client
:
EngineClient
,
models
:
OpenAIServingModels
,
*
,
request_logger
:
RequestLogger
|
None
,
chat_template
:
str
|
None
=
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
self
.
trust_request_chat_template
=
trust_request_chat_template
async
def
_preprocess
(
self
,
ctx
:
ServeContext
,
ctx
:
Classification
ServeContext
,
)
->
ErrorResponse
|
None
:
"""
Process classification inputs: tokenize text, resolve adapters,
and prepare model-specific inputs.
"""
ctx
=
cast
(
ClassificationServeContext
,
ctx
)
try
:
request_obj
=
ctx
.
request
if
isinstance
(
request_obj
,
ClassificationChatRequest
):
chat_request
=
request_obj
messages
=
chat_request
.
messages
trust_request_chat_template
=
getattr
(
self
,
"trust_request_chat_template"
,
False
,
)
ret
=
self
.
_validate_chat_template
(
request_chat_template
=
chat_request
.
chat_template
,
chat_template_kwargs
=
chat_request
.
chat_template_kwargs
,
trust_request_chat_template
=
trust_request_chat_template
,
ctx
.
lora_request
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
if
isinstance
(
ctx
.
request
,
ClassificationChatRequest
):
error_check_ret
=
self
.
_validate_chat_template
(
request_chat_template
=
ctx
.
request
.
chat_template
,
chat_template_kwargs
=
ctx
.
request
.
chat_template_kwargs
,
trust_request_chat_template
=
self
.
trust_request_chat_template
,
)
if
ret
:
return
ret
if
error_check_
ret
:
return
error_check_
ret
_
,
engine_prompts
=
await
self
.
_preprocess_chat
(
c
ast
(
ChatCompletionRequest
,
chat_
request
)
,
c
tx
.
request
,
self
.
renderer
,
messages
,
chat_template
=
(
chat_request
.
chat_template
or
getattr
(
self
,
"chat_template"
,
None
)
),
chat_template_content_format
=
cast
(
ChatTemplateContentFormatOption
,
getattr
(
self
,
"chat_template_content_format"
,
"auto"
),
),
add_generation_prompt
=
chat_request
.
add_generation_prompt
,
continue_final_message
=
chat_request
.
continue_final_message
,
add_special_tokens
=
chat_request
.
add_special_tokens
,
ctx
.
request
.
messages
,
chat_template
=
ctx
.
request
.
chat_template
or
self
.
chat_template
,
chat_template_content_format
=
self
.
chat_template_content_format
,
add_generation_prompt
=
ctx
.
request
.
add_generation_prompt
,
continue_final_message
=
ctx
.
request
.
continue_final_message
,
add_special_tokens
=
ctx
.
request
.
add_special_tokens
,
)
ctx
.
engine_prompts
=
engine_prompts
elif
isinstance
(
request_obj
,
ClassificationCompletionRequest
):
completion_request
=
request_obj
input_data
=
completion_request
.
input
elif
isinstance
(
ctx
.
request
,
ClassificationCompletionRequest
):
input_data
=
ctx
.
request
.
input
if
input_data
in
(
None
,
""
):
return
self
.
create_error_response
(
"Input or messages must be provided"
,
...
...
@@ -106,13 +104,10 @@ class ClassificationMixin(OpenAIServing):
prompt_input
=
cast
(
str
|
list
[
str
],
input_data
)
ctx
.
engine_prompts
=
await
renderer
.
render_prompt
(
prompt_or_prompts
=
prompt_input
,
config
=
self
.
_build_render_config
(
c
ompletion_
request
),
config
=
self
.
_build_render_config
(
c
tx
.
request
),
)
else
:
return
self
.
create_error_response
(
"Invalid classification request type"
,
status_code
=
HTTPStatus
.
BAD_REQUEST
,
)
return
self
.
create_error_response
(
"Invalid classification request type"
)
return
None
...
...
@@ -122,13 +117,14 @@ class ClassificationMixin(OpenAIServing):
def
_build_response
(
self
,
ctx
:
ServeContext
,
ctx
:
Classification
ServeContext
,
)
->
ClassificationResponse
|
ErrorResponse
:
"""
Convert model outputs to a formatted classification response
with probabilities and labels.
"""
ctx
=
cast
(
ClassificationServeContext
,
ctx
)
id2label
=
getattr
(
self
.
model_config
.
hf_config
,
"id2label"
,
{})
items
:
list
[
ClassificationData
]
=
[]
num_prompt_tokens
=
0
...
...
@@ -139,9 +135,7 @@ class ClassificationMixin(OpenAIServing):
probs
=
classify_res
.
probs
predicted_index
=
int
(
np
.
argmax
(
probs
))
label
=
getattr
(
self
.
model_config
.
hf_config
,
"id2label"
,
{}).
get
(
predicted_index
)
label
=
id2label
.
get
(
predicted_index
)
item
=
ClassificationData
(
index
=
idx
,
...
...
@@ -174,32 +168,6 @@ class ClassificationMixin(OpenAIServing):
add_special_tokens
=
request
.
add_special_tokens
,
)
class
ServingClassification
(
ClassificationMixin
):
request_id_prefix
=
"classify"
def
__init__
(
self
,
engine_client
:
EngineClient
,
models
:
OpenAIServingModels
,
*
,
request_logger
:
RequestLogger
|
None
,
chat_template
:
str
|
None
=
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
=
"auto"
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
=
chat_template_content_format
self
.
trust_request_chat_template
=
trust_request_chat_template
async
def
create_classify
(
self
,
request
:
ClassificationRequest
,
...
...
@@ -215,11 +183,11 @@ class ServingClassification(ClassificationMixin):
request_id
=
request_id
,
)
return
await
s
uper
()
.
handle
(
ctx
)
# type: ignore
return
await
s
elf
.
handle
(
ctx
)
# type: ignore
[return-value]
def
_create_pooling_params
(
self
,
ctx
:
ServeContext
[
Classification
Request
]
,
ctx
:
Classification
ServeContext
,
)
->
PoolingParams
|
ErrorResponse
:
pooling_params
=
super
().
_create_pooling_params
(
ctx
)
if
isinstance
(
pooling_params
,
ErrorResponse
):
...
...
vllm/entrypoints/pooling/embed/serving.py
View file @
d76fc11e
...
...
@@ -6,21 +6,13 @@ from typing import Any, Final, cast
import
torch
from
fastapi
import
Request
from
fastapi.responses
import
Response
from
typing_extensions
import
assert_never
,
override
from
typing_extensions
import
assert_never
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
ChatTemplateContentFormatOption
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.engine.protocol
import
(
ErrorResponse
,
UsageInfo
,
)
from
vllm.entrypoints.openai.engine.serving
import
(
EmbeddingServeContext
,
OpenAIServing
,
ServeContext
,
)
from
vllm.entrypoints.openai.engine.protocol
import
ErrorResponse
,
UsageInfo
from
vllm.entrypoints.openai.engine.serving
import
OpenAIServing
,
ServeContext
from
vllm.entrypoints.openai.models.serving
import
OpenAIServingModels
from
vllm.entrypoints.pooling.embed.protocol
import
(
EmbeddingBytesResponse
,
...
...
@@ -33,19 +25,11 @@ from vllm.entrypoints.pooling.embed.protocol import (
from
vllm.entrypoints.renderer
import
RenderConfig
from
vllm.inputs.data
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.outputs
import
(
EmbeddingRequestOutput
,
PoolingOutput
,
PoolingRequestOutput
,
RequestOutput
,
)
from
vllm.outputs
import
PoolingOutput
,
PoolingRequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.utils.async_utils
import
merge_async_iterators
from
vllm.utils.collection_utils
import
chunk_list
from
vllm.utils.serial_utils
import
(
EmbedDType
,
EncodingFormat
,
Endianness
,
encode_pooling_bytes
,
encode_pooling_output
,
)
...
...
@@ -53,9 +37,33 @@ from vllm.utils.serial_utils import (
logger
=
init_logger
(
__name__
)
class
EmbeddingMixin
(
OpenAIServing
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
EmbeddingServeContext
=
ServeContext
[
EmbeddingRequest
]
class
OpenAIServingEmbedding
(
OpenAIServing
):
request_id_prefix
=
"embd"
def
__init__
(
self
,
engine_client
:
EngineClient
,
models
:
OpenAIServingModels
,
*
,
request_logger
:
RequestLogger
|
None
,
chat_template
:
str
|
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
self
.
trust_request_chat_template
=
trust_request_chat_template
pooler_config
=
self
.
model_config
.
pooler_config
...
...
@@ -69,32 +77,41 @@ class EmbeddingMixin(OpenAIServing):
else
None
)
@
override
async
def
_preprocess
(
self
,
ctx
:
ServeContext
,
ctx
:
Embedding
ServeContext
,
)
->
ErrorResponse
|
None
:
ctx
=
cast
(
EmbeddingServeContext
,
ctx
)
try
:
ctx
.
lora_request
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
if
isinstance
(
ctx
.
request
,
EmbeddingChatRequest
):
error_check_ret
=
self
.
_validate_chat_template
(
request_chat_template
=
ctx
.
request
.
chat_template
,
chat_template_kwargs
=
ctx
.
request
.
chat_template_kwargs
,
trust_request_chat_template
=
self
.
trust_request_chat_template
,
)
if
error_check_ret
is
not
None
:
return
error_check_ret
_
,
ctx
.
engine_prompts
=
await
self
.
_preprocess_chat
(
ctx
.
request
,
self
.
renderer
,
ctx
.
request
.
messages
,
chat_template
=
ctx
.
request
.
chat_template
or
ctx
.
chat_template
,
chat_template_content_format
=
ctx
.
chat_template_content_format
,
chat_template
=
ctx
.
request
.
chat_template
or
self
.
chat_template
,
chat_template_content_format
=
self
.
chat_template_content_format
,
add_generation_prompt
=
ctx
.
request
.
add_generation_prompt
,
continue_final_message
=
ctx
.
request
.
continue_final_message
,
add_special_tokens
=
ctx
.
request
.
add_special_tokens
,
)
el
se
:
el
if
isinstance
(
ctx
.
request
,
EmbeddingCompletionRequest
)
:
renderer
=
self
.
_get_completion_renderer
()
ctx
.
engine_prompts
=
await
renderer
.
render_prompt
(
prompt_or_prompts
=
ctx
.
request
.
input
,
config
=
self
.
_build_render_config
(
ctx
.
request
),
)
else
:
return
self
.
create_error_response
(
"Invalid classification request type"
)
return
None
except
(
ValueError
,
TypeError
)
as
e
:
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
...
...
@@ -113,16 +130,15 @@ class EmbeddingMixin(OpenAIServing):
add_special_tokens
=
request
.
add_special_tokens
,
)
@
override
def
_build_response
(
self
,
ctx
:
ServeContext
,
)
->
EmbeddingResponse
|
Response
|
ErrorResponse
:
final_res_batch_checked
=
cast
(
list
[
PoolingRequestOutput
],
ctx
.
final_res_batch
)
ctx
:
Embedding
ServeContext
,
)
->
EmbeddingResponse
|
EmbeddingBytes
Response
|
ErrorResponse
:
final_res_batch_checked
=
ctx
.
final_res_batch
encoding_format
:
EncodingFormat
=
ctx
.
request
.
encoding_format
embed_dtype
:
EmbedDType
=
ctx
.
request
.
embed_dtype
endianness
:
Endianness
=
ctx
.
request
.
endianness
encoding_format
=
ctx
.
request
.
encoding_format
embed_dtype
=
ctx
.
request
.
embed_dtype
endianness
=
ctx
.
request
.
endianness
def
encode_float_base64
():
items
:
list
[
EmbeddingResponseData
]
=
[]
...
...
@@ -203,8 +219,8 @@ class EmbeddingMixin(OpenAIServing):
self
,
ctx
:
EmbeddingServeContext
,
token_ids
:
list
[
int
],
pooling_params
,
trace_headers
,
pooling_params
:
PoolingParams
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
,
prompt_idx
:
int
,
)
->
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]:
"""Process a single prompt using chunked processing."""
...
...
@@ -246,7 +262,7 @@ class EmbeddingMixin(OpenAIServing):
def
_validate_input
(
self
,
request
,
request
:
object
,
input_ids
:
list
[
int
],
input_text
:
str
,
)
->
TokensPrompt
:
...
...
@@ -326,7 +342,7 @@ class EmbeddingMixin(OpenAIServing):
pooling_params
:
PoolingParams
,
trace_headers
:
Mapping
[
str
,
str
]
|
None
,
prompt_index
:
int
,
)
->
AsyncGenerator
[
RequestOutput
|
PoolingRequestOutput
,
None
]:
)
->
AsyncGenerator
[
PoolingRequestOutput
,
None
]:
"""Create a generator for a single prompt using standard processing."""
request_id_item
=
f
"
{
ctx
.
request_id
}
-
{
prompt_index
}
"
...
...
@@ -347,7 +363,6 @@ class EmbeddingMixin(OpenAIServing):
priority
=
getattr
(
ctx
.
request
,
"priority"
,
0
),
)
@
override
async
def
_prepare_generators
(
self
,
ctx
:
ServeContext
,
...
...
@@ -363,9 +378,7 @@ class EmbeddingMixin(OpenAIServing):
return
await
super
().
_prepare_generators
(
ctx
)
# Custom logic for chunked processing
generators
:
list
[
AsyncGenerator
[
RequestOutput
|
PoolingRequestOutput
,
None
]
]
=
[]
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
try
:
trace_headers
=
(
...
...
@@ -419,10 +432,9 @@ class EmbeddingMixin(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
@
override
async
def
_collect_batch
(
self
,
ctx
:
ServeContext
,
ctx
:
Embedding
ServeContext
,
)
->
ErrorResponse
|
None
:
"""Collect and aggregate batch results
with support for chunked processing.
...
...
@@ -431,7 +443,6 @@ class EmbeddingMixin(OpenAIServing):
minimize memory usage.
For regular requests, collects results normally.
"""
ctx
=
cast
(
EmbeddingServeContext
,
ctx
)
try
:
if
ctx
.
engine_prompts
is
None
:
return
self
.
create_error_response
(
"Engine prompts not available"
)
...
...
@@ -527,12 +538,10 @@ class EmbeddingMixin(OpenAIServing):
except
(
ValueError
,
IndexError
):
prompt_idx
=
result_idx
# Fallback to result_idx
short_prompts_results
[
prompt_idx
]
=
cast
(
PoolingRequestOutput
,
result
)
short_prompts_results
[
prompt_idx
]
=
result
# Finalize aggregated results
final_res_batch
:
list
[
PoolingRequestOutput
|
EmbeddingRequestOutput
]
=
[]
final_res_batch
:
list
[
PoolingRequestOutput
]
=
[]
num_prompts
=
len
(
ctx
.
engine_prompts
)
for
prompt_idx
in
range
(
num_prompts
):
...
...
@@ -580,49 +589,19 @@ class EmbeddingMixin(OpenAIServing):
f
"Failed to aggregate chunks for prompt
{
prompt_idx
}
"
)
elif
prompt_idx
in
short_prompts_results
:
final_res_batch
.
append
(
cast
(
PoolingRequestOutput
,
short_prompts_results
[
prompt_idx
])
)
final_res_batch
.
append
(
short_prompts_results
[
prompt_idx
])
else
:
return
self
.
create_error_response
(
f
"Result not found for prompt
{
prompt_idx
}
"
)
ctx
.
final_res_batch
=
cast
(
list
[
RequestOutput
|
PoolingRequestOutput
],
final_res_batch
)
ctx
.
final_res_batch
=
final_res_batch
return
None
except
Exception
as
e
:
return
self
.
create_error_response
(
str
(
e
))
class
OpenAIServingEmbedding
(
EmbeddingMixin
):
request_id_prefix
=
"embd"
def
__init__
(
self
,
engine_client
:
EngineClient
,
models
:
OpenAIServingModels
,
*
,
request_logger
:
RequestLogger
|
None
,
chat_template
:
str
|
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
self
.
trust_request_chat_template
=
trust_request_chat_template
async
def
create_embedding
(
self
,
request
:
EmbeddingRequest
,
...
...
@@ -645,16 +624,13 @@ class OpenAIServingEmbedding(EmbeddingMixin):
raw_request
=
raw_request
,
model_name
=
model_name
,
request_id
=
request_id
,
chat_template
=
self
.
chat_template
,
chat_template_content_format
=
self
.
chat_template_content_format
,
)
return
await
s
uper
()
.
handle
(
ctx
)
# type: ignore
return
await
s
elf
.
handle
(
ctx
)
# type: ignore
[return-value]
@
override
def
_create_pooling_params
(
self
,
ctx
:
ServeContext
[
EmbeddingRequest
]
,
ctx
:
Embedding
ServeContext
,
)
->
PoolingParams
|
ErrorResponse
:
pooling_params
=
super
().
_create_pooling_params
(
ctx
)
if
isinstance
(
pooling_params
,
ErrorResponse
):
...
...
@@ -666,17 +642,3 @@ class OpenAIServingEmbedding(EmbeddingMixin):
return
self
.
create_error_response
(
str
(
e
))
return
pooling_params
async
def
_preprocess
(
self
,
ctx
:
ServeContext
,
)
->
ErrorResponse
|
None
:
if
isinstance
(
ctx
.
request
,
EmbeddingChatRequest
):
error_check_ret
=
self
.
_validate_chat_template
(
request_chat_template
=
ctx
.
request
.
chat_template
,
chat_template_kwargs
=
ctx
.
request
.
chat_template_kwargs
,
trust_request_chat_template
=
self
.
trust_request_chat_template
,
)
if
error_check_ret
is
not
None
:
return
error_check_ret
return
await
super
().
_preprocess
(
ctx
)
vllm/entrypoints/utils.py
View file @
d76fc11e
...
...
@@ -17,8 +17,10 @@ from starlette.background import BackgroundTask, BackgroundTasks
from
vllm
import
envs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.inputs
import
EmbedsPrompt
,
TokensPrompt
from
vllm.logger
import
current_formatter_type
,
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
length_from_prompt_token_ids_or_embeds
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
if
TYPE_CHECKING
:
...
...
@@ -32,11 +34,15 @@ if TYPE_CHECKING:
StreamOptions
,
)
from
vllm.entrypoints.openai.models.protocol
import
LoRAModulePath
from
vllm.entrypoints.openai.responses.protocol
import
(
ResponsesRequest
,
)
else
:
ChatCompletionRequest
=
object
CompletionRequest
=
object
StreamOptions
=
object
LoRAModulePath
=
object
ResponsesRequest
=
object
logger
=
init_logger
(
__name__
)
...
...
@@ -211,11 +217,26 @@ def _validate_truncation_size(
def
get_max_tokens
(
max_model_len
:
int
,
request
:
"
Chat
CompletionRequest | CompletionRequest"
,
input_length
:
in
t
,
request
:
"CompletionRequest |
Chat
CompletionRequest
| ResponsesRequest
"
,
prompt
:
TokensPrompt
|
EmbedsPromp
t
,
default_sampling_params
:
dict
,
)
->
int
:
max_tokens
=
getattr
(
request
,
"max_completion_tokens"
,
None
)
or
request
.
max_tokens
# NOTE: Avoid isinstance() for better efficiency
max_tokens
:
int
|
None
=
None
if
max_tokens
is
None
:
# ChatCompletionRequest
max_tokens
=
getattr
(
request
,
"max_completion_tokens"
,
None
)
if
max_tokens
is
None
:
# ResponsesRequest
max_tokens
=
getattr
(
request
,
"max_output_tokens"
,
None
)
if
max_tokens
is
None
:
# CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens
=
getattr
(
request
,
"max_tokens"
,
None
)
input_length
=
length_from_prompt_token_ids_or_embeds
(
prompt
.
get
(
"prompt_token_ids"
),
# type: ignore[arg-type]
prompt
.
get
(
"prompt_embeds"
),
# type: ignore[arg-type]
)
default_max_tokens
=
max_model_len
-
input_length
max_output_tokens
=
current_platform
.
get_max_output_tokens
(
input_length
)
...
...
vllm/envs.py
View file @
d76fc11e
...
...
@@ -87,6 +87,7 @@ if TYPE_CHECKING:
VLLM_HTTP_TIMEOUT_KEEP_ALIVE
:
int
=
5
# seconds
VLLM_PLUGINS
:
list
[
str
]
|
None
=
None
VLLM_LORA_RESOLVER_CACHE_DIR
:
str
|
None
=
None
VLLM_LORA_RESOLVER_HF_REPO_LIST
:
str
|
None
=
None
# Deprecated env variables for profiling, kept for backward compatibility
# See also vllm/config/profiler.py and `--profiler-config` argument
VLLM_TORCH_CUDA_PROFILE
:
str
|
None
=
None
...
...
@@ -325,16 +326,11 @@ def use_aot_compile() -> bool:
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_is_batch_invariant
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
default_value
=
(
"1"
if
is_torch_equal_or_newer
(
"2.10.0.dev"
)
and
not
disable_compile_cache
()
# Disabling AOT_COMPILE for CPU
# See: https://github.com/vllm-project/vllm/issues/32033
and
not
current_platform
.
is_cpu
()
if
is_torch_equal_or_newer
(
"2.10.0.dev"
)
and
not
disable_compile_cache
()
else
"0"
)
...
...
@@ -823,6 +819,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
),
# Backend for Video IO
# - "opencv": Default backend that uses OpenCV stream buffered backend.
# - "identity": Returns raw video bytes for model processor to handle.
#
# Custom backend implementations can be registered
# via `@VIDEO_LOADER_REGISTRY.register("my_custom_video_loader")` and
...
...
@@ -914,6 +911,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_LORA_RESOLVER_CACHE_DIR"
:
lambda
:
os
.
getenv
(
"VLLM_LORA_RESOLVER_CACHE_DIR"
,
None
),
# A remote HF repo(s) containing one or more LoRA adapters, which
# may be downloaded and leveraged as needed. Only works if plugins
# are enabled and VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled.
# Values should be comma separated.
"VLLM_LORA_RESOLVER_HF_REPO_LIST"
:
lambda
:
os
.
getenv
(
"VLLM_LORA_RESOLVER_HF_REPO_LIST"
,
None
),
# Enables torch CUDA profiling if set to 1.
# Deprecated, see profiler_config.
"VLLM_TORCH_CUDA_PROFILE"
:
lambda
:
os
.
getenv
(
"VLLM_TORCH_CUDA_PROFILE"
),
...
...
vllm/logging_utils/__init__.py
View file @
d76fc11e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.logging_utils.access_log_filter
import
(
UvicornAccessLogFilter
,
create_uvicorn_log_config
,
)
from
vllm.logging_utils.formatter
import
ColoredFormatter
,
NewLineFormatter
from
vllm.logging_utils.lazy
import
lazy
from
vllm.logging_utils.log_time
import
logtime
...
...
@@ -8,6 +12,8 @@ from vllm.logging_utils.log_time import logtime
__all__
=
[
"NewLineFormatter"
,
"ColoredFormatter"
,
"UvicornAccessLogFilter"
,
"create_uvicorn_log_config"
,
"lazy"
,
"logtime"
,
]
vllm/logging_utils/access_log_filter.py
0 → 100644
View file @
d76fc11e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Access log filter for uvicorn to exclude specific endpoints from logging.
This module provides a logging filter that can be used to suppress access logs
for specific endpoints (e.g., /health, /metrics) to reduce log noise in
production environments.
"""
import
logging
from
urllib.parse
import
urlparse
class
UvicornAccessLogFilter
(
logging
.
Filter
):
"""
A logging filter that excludes access logs for specified endpoint paths.
This filter is designed to work with uvicorn's access logger. It checks
the log record's arguments for the request path and filters out records
matching the excluded paths.
Uvicorn access log format:
'%s - "%s %s HTTP/%s" %d'
(client_addr, method, path, http_version, status_code)
Example:
127.0.0.1:12345 - "GET /health HTTP/1.1" 200
Args:
excluded_paths: A list of URL paths to exclude from logging.
Paths are matched exactly.
Example: ["/health", "/metrics"]
"""
def
__init__
(
self
,
excluded_paths
:
list
[
str
]
|
None
=
None
):
super
().
__init__
()
self
.
excluded_paths
=
set
(
excluded_paths
or
[])
def
filter
(
self
,
record
:
logging
.
LogRecord
)
->
bool
:
"""
Determine if the log record should be logged.
Args:
record: The log record to evaluate.
Returns:
True if the record should be logged, False otherwise.
"""
if
not
self
.
excluded_paths
:
return
True
# This filter is specific to uvicorn's access logs.
if
record
.
name
!=
"uvicorn.access"
:
return
True
# The path is the 3rd argument in the log record's args tuple.
# See uvicorn's access logging implementation for details.
log_args
=
record
.
args
if
isinstance
(
log_args
,
tuple
)
and
len
(
log_args
)
>=
3
:
path_with_query
=
log_args
[
2
]
# Get path component without query string.
if
isinstance
(
path_with_query
,
str
):
path
=
urlparse
(
path_with_query
).
path
if
path
in
self
.
excluded_paths
:
return
False
return
True
def
create_uvicorn_log_config
(
excluded_paths
:
list
[
str
]
|
None
=
None
,
log_level
:
str
=
"info"
,
)
->
dict
:
"""
Create a uvicorn logging configuration with access log filtering.
This function generates a logging configuration dictionary that can be
passed to uvicorn's `log_config` parameter. It sets up the access log
filter to exclude specified paths.
Args:
excluded_paths: List of URL paths to exclude from access logs.
log_level: The log level for uvicorn loggers.
Returns:
A dictionary containing the logging configuration.
Example:
>>> config = create_uvicorn_log_config(["/health", "/metrics"])
>>> uvicorn.run(app, log_config=config)
"""
config
=
{
"version"
:
1
,
"disable_existing_loggers"
:
False
,
"filters"
:
{
"access_log_filter"
:
{
"()"
:
UvicornAccessLogFilter
,
"excluded_paths"
:
excluded_paths
or
[],
},
},
"formatters"
:
{
"default"
:
{
"()"
:
"uvicorn.logging.DefaultFormatter"
,
"fmt"
:
"%(levelprefix)s %(message)s"
,
"use_colors"
:
None
,
},
"access"
:
{
"()"
:
"uvicorn.logging.AccessFormatter"
,
"fmt"
:
'%(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
,
# noqa: E501
},
},
"handlers"
:
{
"default"
:
{
"formatter"
:
"default"
,
"class"
:
"logging.StreamHandler"
,
"stream"
:
"ext://sys.stderr"
,
},
"access"
:
{
"formatter"
:
"access"
,
"class"
:
"logging.StreamHandler"
,
"stream"
:
"ext://sys.stdout"
,
"filters"
:
[
"access_log_filter"
],
},
},
"loggers"
:
{
"uvicorn"
:
{
"handlers"
:
[
"default"
],
"level"
:
log_level
.
upper
(),
"propagate"
:
False
,
},
"uvicorn.error"
:
{
"level"
:
log_level
.
upper
(),
"handlers"
:
[
"default"
],
"propagate"
:
False
,
},
"uvicorn.access"
:
{
"handlers"
:
[
"access"
],
"level"
:
log_level
.
upper
(),
"propagate"
:
False
,
},
},
}
return
config
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
View file @
d76fc11e
...
...
@@ -62,6 +62,7 @@ def _fused_moe_lora_kernel(
num_experts
,
lora_ids
,
adapter_enabled
,
max_loras
,
# <<< PR2: rename, used for masks when grid axis-2 != max_loras
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
...
...
@@ -83,6 +84,7 @@ def _fused_moe_lora_kernel(
num_slice_c
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
USE_B_L2_CACHE
:
tl
.
constexpr
,
# new, enable .ca load for B
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
...
...
@@ -104,10 +106,13 @@ def _fused_moe_lora_kernel(
if
moe_enabled
==
0
:
# Early exit for the no moe lora case.
return
# The grid size on axis 2 is (max_loras + 1) to handle the no-lora case
# (lora_id == -1), but sorted_token_ids and expert_ids are allocated with
# shape (max_loras, ...). Use (num_programs - 1) for correct bounds checking.
max_loras
=
tl
.
num_programs
(
axis
=
2
)
-
1
# The grid's axis-2 dimension is max_loras + 1 to accommodate the -1 sentinel.
# This guard ensures we don't access sorted_token_ids / expert_ids /
# num_tokens_post_padded beyond their allocated bounds if an invalid
# lora_id somehow appears. Although the caller should pass correct
# max_loras, defensive programming prevents accidental out-of-bounds.
if
lora_id
>=
max_loras
:
return
grid_k
=
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
*
SPLIT_K
)
# calculate pid_m,pid_n
...
...
@@ -136,10 +141,11 @@ def _fused_moe_lora_kernel(
cur_b_ptr
=
tl
.
load
(
b_ptr
+
slice_id
).
to
(
tl
.
pointer_type
(
c_ptr
.
dtype
.
element_ty
))
cur_c_ptr
=
c_ptr
+
(
slice_id
%
num_slice_c
)
*
slice_c_size
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
).
to
(
tl
.
int64
))
%
N
# remove modulo wrap-around
offs_bn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
).
to
(
tl
.
int32
)
offs_k
=
pid_sk
*
BLOCK_SIZE_K
+
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
offs_token_id
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
).
to
(
tl
.
int
64
)
offs_token_id
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
).
to
(
tl
.
int
32
)
token_ind
=
stride_tl
*
lora_id
+
offs_token_id
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
token_ind
,
...
...
@@ -176,7 +182,13 @@ def _fused_moe_lora_kernel(
# GDC wait waits for ALL programs in the prior kernel to complete
# before continuing.
# pre-fetch lora weight
b
=
tl
.
load
(
b_ptrs
,
mask
=
offs_k
[:,
None
]
<
k_remaining
,
other
=
0.0
)
# add (offs_bn < N) mask; optional .ca for B
b_mask
=
(
offs_k
[:,
None
]
<
k_remaining
)
&
(
offs_bn
[
None
,
:]
<
N
)
if
USE_B_L2_CACHE
:
b
=
tl
.
load
(
b_ptrs
,
mask
=
b_mask
,
other
=
0.0
,
cache_modifier
=
".ca"
)
else
:
b
=
tl
.
load
(
b_ptrs
,
mask
=
b_mask
,
other
=
0.0
)
if
USE_GDC
and
not
IS_PRIMARY
:
tl
.
extra
.
cuda
.
gdc_wait
()
a
=
tl
.
load
(
...
...
@@ -276,6 +288,7 @@ def _fused_moe_lora_shrink(
num_experts
,
lora_ids
,
adapter_enabled
,
lora_a_stacked
[
0
].
shape
[
0
],
qcurr_hidden_states
.
stride
(
0
),
qcurr_hidden_states
.
stride
(
1
),
w1_lora_a_stacked
.
stride
(
0
),
...
...
@@ -292,6 +305,7 @@ def _fused_moe_lora_shrink(
num_slice_c
=
num_slices
,
top_k
=
1
if
mul_routed_weight
else
top_k_num
,
MUL_ROUTED_WEIGHT
=
False
,
USE_B_L2_CACHE
=
True
,
# new
IS_PRIMARY
=
True
,
**
shrink_config
,
)
...
...
@@ -377,6 +391,7 @@ def _fused_moe_lora_expand(
num_experts
,
lora_ids
,
adapter_enabled
,
lora_b_stacked
[
0
].
shape
[
0
],
a_intermediate_cache1
.
stride
(
0
),
a_intermediate_cache1
.
stride
(
1
),
w1_lora_b_stacked
.
stride
(
0
),
...
...
@@ -393,6 +408,7 @@ def _fused_moe_lora_expand(
num_slice_c
=
num_slices
,
top_k
=
1
,
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
USE_B_L2_CACHE
=
True
,
# new
IS_PRIMARY
=
False
,
**
expand_config
,
)
...
...
vllm/model_executor/layers/fused_moe/all2all_utils.py
View file @
d76fc11e
...
...
@@ -7,17 +7,27 @@ import torch
from
vllm.distributed
import
(
get_ep_group
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEParallelConfig
,
FusedMoEQuantConfig
,
)
from
vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize
import
(
FlashInferA2APrepareAndFinalize
,
)
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEPrepareAndFinalize
,
)
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNaiveEP
,
MoEPrepareAndFinalizeNoEP
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.import_utils
import
has_deep_ep
,
has_mori
,
has_pplx
logger
=
init_logger
(
__name__
)
if
current_platform
.
is_cuda_alike
():
if
has_pplx
():
from
.pplx_prepare_finalize
import
(
...
...
@@ -70,20 +80,46 @@ def maybe_make_prepare_finalize(
moe
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
|
None
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
allow_new_interface
:
bool
=
False
,
)
->
FusedMoEPrepareAndFinalize
|
None
:
# NOTE(rob): we are migrating each quant_method to hold the MK
# in all cases. The allow_new_interface=False flag allows us to fall
# back to the old method for methods that have not yet been migrated.
#
# In old method:
# * maybe_init_modular_kernel() calls this function. If we are
# using no DP/EP or naive all2all, this function returns None
# and no ModularKernelMethod is created. If non-naive
# all2all is used, this returns a PrepareAndFinalize object and
# a ModularKernelMethod is created.
# In new method:
# * maybe_make_prepare_finalize() is called from the oracle. We
# always return a PrepareAndFinalize object and the quant method
# holds the ModularKernel.
if
not
moe
.
moe_parallel_config
.
use_all2all_kernels
:
return
None
if
not
allow_new_interface
:
return
None
# For DP/TP case, fall back to naive P/F.
if
moe
.
moe_parallel_config
.
dp_size
>
1
:
logger
.
info_once
(
"Detected DP deployment with no --enable-expert-parallel. "
"Falling back to AllGather+ReduceScatter dispatch/combine."
)
return
MoEPrepareAndFinalizeNaiveEP
(
is_sequence_parallel
=
moe
.
moe_parallel_config
.
is_sequence_parallel
,
num_dispatchers
=
(
get_ep_group
().
device_communicator
.
all2all_manager
.
world_size
),
)
else
:
return
MoEPrepareAndFinalizeNoEP
()
all2all_manager
=
get_ep_group
().
device_communicator
.
all2all_manager
assert
all2all_manager
is
not
None
prepare_finalize
:
FusedMoEPrepareAndFinalize
|
None
=
None
# TODO(rob): update this as part of the MoE refactor.
assert
not
moe
.
use_flashinfer_cutlass_kernels
,
(
"Must be created in modelopt.py or fp8.py"
)
if
moe
.
use_pplx_kernels
:
assert
quant_config
is
not
None
...
...
@@ -203,4 +239,16 @@ def maybe_make_prepare_finalize(
use_fp8_dispatch
=
use_fp8_dispatch
,
)
elif
moe
.
use_fi_all2allv_kernels
:
assert
quant_config
is
not
None
prepare_finalize
=
FlashInferA2APrepareAndFinalize
(
num_dispatchers
=
all2all_manager
.
world_size
,
)
elif
moe
.
use_naive_all2all_kernels
and
allow_new_interface
:
prepare_finalize
=
MoEPrepareAndFinalizeNaiveEP
(
is_sequence_parallel
=
(
moe
.
moe_parallel_config
.
is_sequence_parallel
),
num_dispatchers
=
all2all_manager
.
world_size
,
)
return
prepare_finalize
vllm/model_executor/layers/fused_moe/config.py
View file @
d76fc11e
...
...
@@ -20,7 +20,6 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
from
vllm.utils.import_utils
import
has_triton_kernels
from
vllm.utils.math_utils
import
cdiv
...
...
@@ -872,6 +871,7 @@ class FusedMoEParallelConfig:
use_ep
:
bool
# whether to use EP or not
all2all_backend
:
str
# all2all backend for MoE communication
is_sequence_parallel
:
bool
# whether sequence parallelism is used
enable_eplb
:
bool
# whether to enable expert load balancing
@
property
...
...
@@ -893,6 +893,12 @@ class FusedMoEParallelConfig:
def
use_deepep_ll_kernels
(
self
):
return
self
.
use_all2all_kernels
and
self
.
all2all_backend
==
"deepep_low_latency"
@
property
def
use_fi_all2allv_kernels
(
self
):
return
(
self
.
use_all2all_kernels
and
self
.
all2all_backend
==
"flashinfer_all2allv"
)
@
property
def
use_batched_activation_format
(
self
):
return
self
.
use_deepep_ll_kernels
or
self
.
use_pplx_kernels
...
...
@@ -1024,6 +1030,7 @@ class FusedMoEParallelConfig:
ep_rank
=
0
,
use_ep
=
False
,
all2all_backend
=
vllm_parallel_config
.
all2all_backend
,
is_sequence_parallel
=
vllm_parallel_config
.
use_sequence_parallel_moe
,
enable_eplb
=
vllm_parallel_config
.
enable_eplb
,
)
# DP + EP / TP + EP / DP + TP + EP
...
...
@@ -1043,6 +1050,7 @@ class FusedMoEParallelConfig:
ep_rank
=
ep_rank
,
use_ep
=
True
,
all2all_backend
=
vllm_parallel_config
.
all2all_backend
,
is_sequence_parallel
=
vllm_parallel_config
.
use_sequence_parallel_moe
,
enable_eplb
=
vllm_parallel_config
.
enable_eplb
,
)
...
...
@@ -1061,6 +1069,7 @@ class FusedMoEParallelConfig:
use_ep
=
False
,
all2all_backend
=
"naive"
,
enable_eplb
=
False
,
is_sequence_parallel
=
False
,
)
...
...
@@ -1155,12 +1164,9 @@ class FusedMoEConfig:
return
self
.
moe_parallel_config
.
use_mori_kernels
@
property
def
use_flashinfer_cutlass_kernels
(
self
):
"""
Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
"""
return
(
envs
.
VLLM_USE_FLASHINFER_MOE_FP4
and
has_flashinfer_cutlass_fused_moe
()
and
envs
.
VLLM_FLASHINFER_MOE_BACKEND
==
"throughput"
)
def
use_fi_all2allv_kernels
(
self
):
return
self
.
moe_parallel_config
.
use_fi_all2allv_kernels
@
property
def
use_naive_all2all_kernels
(
self
):
return
self
.
moe_parallel_config
.
use_naive_all2all_kernels
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
d76fc11e
...
...
@@ -103,7 +103,14 @@ def run_cutlass_moe_fp8(
or
a2_scale
.
size
(
0
)
==
a1q
.
shape
[
0
]
),
"Intermediate scale shape mismatch"
assert
out_dtype
in
[
torch
.
half
,
torch
.
bfloat16
],
"Invalid output dtype"
if
expert_map
is
not
None
:
# NOTE(rob): the expert_map is used for the STANDARD case and
# the batched format is used by the BATCHED case.
# TODO(rob): update the MK interface to only pass the expert_map
# during the STANDARD case to make this clearer across all kernels.
if
use_batched_format
:
assert
expert_num_tokens
is
not
None
else
:
assert
expert_num_tokens
is
None
# We have two modes: batched experts and non-batched experts.
...
...
@@ -379,7 +386,10 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
# needed for STANDARD activation format kernels in DP/EP mode.
# Note that the BATCHED activation format does not use
# the expert map for identifying experts.
return
not
moe_parallel_config
.
use_all2all_kernels
return
not
(
moe_parallel_config
.
use_fi_all2allv_kernels
or
moe_parallel_config
.
use_deepep_ht_kernels
)
def
supports_chunking
(
self
)
->
bool
:
return
True
...
...
@@ -641,10 +651,8 @@ def run_cutlass_moe_fp4(
class
CutlassExpertsFp4
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
@
staticmethod
def
expects_unquantized_inputs
(
moe_config
:
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
)
->
bool
:
@
property
def
expects_unquantized_inputs
(
self
)
->
bool
:
return
True
@
staticmethod
...
...
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
View file @
d76fc11e
...
...
@@ -148,7 +148,8 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
return
True
# NOTE(rob): discovered an IMA with this combination. Needs investigation.
return
not
moe_parallel_config
.
use_fi_all2allv_kernels
def
supports_chunking
(
self
)
->
bool
:
return
True
...
...
vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
View file @
d76fc11e
...
...
@@ -103,6 +103,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts
:
int
,
a1_scale
:
torch
.
Tensor
|
None
,
quant_config
:
FusedMoEQuantConfig
,
defer_input_quant
:
bool
,
)
->
Callable
:
has_scales
=
token_scales
is
not
None
...
...
@@ -174,6 +175,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_topk_weights
,
a1_scale
,
quant_config
,
defer_input_quant
=
defer_input_quant
,
)
def
_receiver
(
...
...
@@ -187,6 +189,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_topk_weights
:
torch
.
Tensor
|
None
,
a1_scale
:
torch
.
Tensor
|
None
,
quant_config
:
FusedMoEQuantConfig
,
defer_input_quant
:
bool
,
)
->
mk
.
PrepareResultType
:
if
event
.
event
is
not
None
:
event
.
current_stream_wait
()
...
...
@@ -221,14 +224,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_num_tokens_per_expert_list
,
device
=
expert_x
.
device
)
# Dispatch and Quant
# DeepEP kernels only support dispatching block-quantized
# activation scales.
# Dispatch in bfloat16 and quantize afterwards
if
not
quant_config
.
is_block_quantized
:
# * For non-block quant, dispatch in b16 and quantize now as
#   DeepEP kernels only support dispatching block scales.
# * For expert kernels that require unquantized inputs,
#   defer quantization to FusedMoEExpertsPermuteUnpermute.
if
not
quant_config
.
is_block_quantized
and
not
defer_input_quant
:
# Quantize after dispatch.
expert_x_scale
=
None
if
expert_x
.
numel
()
!=
0
:
# TODO: support per_act_token_quant,
expert_x
,
expert_x_scale
=
moe_kernel_quantize_input
(
expert_x
,
a1_scale
,
...
...
@@ -257,6 +261,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
defer_input_quant
:
bool
=
False
,
)
->
mk
.
ReceiverType
:
if
apply_router_weight_on_input
:
topk
=
topk_ids
.
size
(
1
)
...
...
@@ -266,8 +271,12 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
)
a1
=
a1
*
topk_weights
.
to
(
a1
.
dtype
)
if
quant_config
.
is_block_quantized
:
# Quant and Dispatch
# * DeepEP only supports fp8 block scales so quantize
# before the dispatch for these models.
# * For all other quantization, dispatch after.
# * For expert kernels that require unquantized inputs,
# defer quantization to FusedMoEExpertsPermuteUnpermute.
if
quant_config
.
is_block_quantized
and
not
defer_input_quant
:
a1q
,
a1q_scale
=
moe_kernel_quantize_input
(
a1
,
quant_config
.
a1_scale
,
...
...
@@ -281,7 +290,11 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
else
:
a1q
=
a1
a1q_scale
=
None
a1_post_scale
=
quant_config
.
a1_scale
a1_post_scale
=
(
quant_config
.
a1_gscale
if
quant_config
.
quant_dtype
==
"nvfp4"
else
quant_config
.
a1_scale
)
return
self
.
_do_dispatch
(
tokens
=
a1q
,
...
...
@@ -291,6 +304,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts
=
num_experts
,
a1_scale
=
a1_post_scale
,
quant_config
=
quant_config
,
defer_input_quant
=
defer_input_quant
,
)
def
prepare
(
...
...
@@ -302,6 +316,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
defer_input_quant
:
bool
=
False
,
)
->
mk
.
PrepareResultType
:
receiver
=
self
.
prepare_async
(
a1
,
...
...
@@ -311,6 +326,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map
,
apply_router_weight_on_input
,
quant_config
,
defer_input_quant
,
)
return
receiver
()
...
...
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
View file @
d76fc11e
...
...
@@ -242,7 +242,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
defer_input_quant
:
bool
=
False
,
)
->
tuple
[
Callable
,
mk
.
ReceiverType
]:
if
defer_input_quant
:
raise
NotImplementedError
(
f
"
{
self
.
__class__
.
__name__
}
does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
hidden_size
=
a1
.
size
(
1
)
assert
hidden_size
in
self
.
SUPPORTED_HIDDEN_SIZES
,
(
f
"Hidden Size
{
hidden_size
}
not in supported list of hidden sizes"
...
...
@@ -344,7 +351,13 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
defer_input_quant
:
bool
=
False
,
)
->
mk
.
PrepareResultType
:
if
defer_input_quant
:
raise
NotImplementedError
(
f
"
{
self
.
__class__
.
__name__
}
does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
hook
,
receiver
=
self
.
prepare_async
(
a1
,
topk_weights
,
...
...
vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
0 → 100644
View file @
d76fc11e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.distributed
import
get_ep_group
from
vllm.distributed.device_communicators.base_device_communicator
import
(
All2AllManagerBase
,
)
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEQuantConfig
from
vllm.model_executor.layers.fused_moe.utils
import
moe_kernel_quantize_input
from
vllm.utils.flashinfer
import
nvfp4_block_scale_interleave
def get_local_sizes():
    """Return the chunk sizes across DP ranks from the current forward context.

    The value is whatever ``dp_metadata.get_chunk_sizes_across_dp_rank()``
    yields for the active forward context.
    """
    dp_metadata = get_forward_context().dp_metadata
    return dp_metadata.get_chunk_sizes_across_dp_rank()
class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    """Prepare/finalize stage built on the FlashInfer all-to-all helpers.

    ``prepare`` sends activations through ``flashinfer_alltoall_dispatch``
    and ``finalize`` brings expert outputs back through
    ``flashinfer_alltoall_combine``. The routing info produced by the
    dispatch is kept on the instance so that the matching ``finalize`` call
    can reuse it.
    """

    def __init__(self, num_dispatchers: int = 1):
        super().__init__()
        self.num_dispatchers_ = num_dispatchers
        ep_communicator = get_ep_group().device_communicator
        self.all2all_manager = ep_communicator.all2all_manager

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        """This implementation works on the Standard activation format."""
        return mk.FusedMoEActivationFormat.Standard

    def max_num_tokens_per_rank(self) -> int | None:
        """No per-rank token limit is reported (always ``None``)."""
        return None

    def topk_indices_dtype(self) -> torch.dtype | None:
        """No specific top-k index dtype is requested (always ``None``)."""
        return None

    def num_dispatchers(self) -> int:
        """Return the dispatcher count given at construction time."""
        return self.num_dispatchers_

    def output_is_reduced(self) -> bool:
        """Always ``False`` for this implementation."""
        return False

    def _apply_router_weight_on_input(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
    ) -> None:
        """Scale ``a1`` in place by the router weights when requested.

        Does nothing unless ``apply_router_weight_on_input`` is true; in that
        case only ``topk == 1`` is accepted.
        """
        if not apply_router_weight_on_input:
            return

        num_topk = topk_ids.size(1)
        assert num_topk == 1, (
            "apply_router_weight_on_input is only implemented for topk=1"
        )
        # In-place multiply: the caller's tensor is modified.
        a1.mul_(topk_weights.to(a1.dtype))

    def prepare(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        defer_input_quant: bool = False,
    ) -> mk.PrepareResultType:
        """Dispatch activations and routing data via the FlashInfer all-to-all.

        Returns ``(a1q, a1q_scale, None, topk_ids, topk_weights)`` using the
        values handed back by ``flashinfer_alltoall_dispatch``. ``expert_map``
        is accepted for interface compatibility and is not used here.
        """
        self._apply_router_weight_on_input(
            a1, topk_weights, topk_ids, apply_router_weight_on_input
        )
        tokens_per_rank = get_local_sizes()
        num_topk = topk_ids.size(1)

        dispatched = flashinfer_alltoall_dispatch(
            self.all2all_manager,
            tokens_per_rank,
            a1,
            quant_config.a1_gscale,
            topk_ids,
            topk_weights,
            num_topk,
            num_experts,
            quant_config,
            defer_input_quant=defer_input_quant,
        )
        # Remember the routing info: finalize() needs it for the combine.
        self.alltoall_info, topk_ids, topk_weights, a1q, a1q_scale = dispatched

        return a1q, a1q_scale, None, topk_ids, topk_weights

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
        """Combine expert outputs across ranks and copy them into ``output``.

        Uses the ``alltoall_info`` stored by the preceding ``prepare`` call.
        ``topk_weights``, ``apply_router_weight_on_input`` and
        ``weight_and_reduce_impl`` are not used by this implementation.
        """
        num_topk = topk_ids.size(1)
        num_tokens = output.shape[0]
        combined = flashinfer_alltoall_combine(
            self.all2all_manager,
            fused_expert_output,
            top_k=num_topk,
            token_count=num_tokens,
            alltoall_info=self.alltoall_info,
        )
        output.copy_(combined)
def flashinfer_alltoall_dispatch(
    all2all_manager: All2AllManagerBase,
    global_num_tokens_cpu: list[int] | None,
    x: torch.Tensor,
    gs: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    top_k: int,
    num_experts: int,
    quant_config: FusedMoEQuantConfig,
    defer_input_quant: bool = False,
):
    """Dispatch activations and routing data through the FlashInfer all-to-all.

    Args:
        all2all_manager: Supplies the rank, world size and workspace tensors
            used by the all-to-all calls.
        global_num_tokens_cpu: Token counts; their maximum is used as the
            max token count. When ``None``, ``x.shape[0]`` is used instead.
        x: Activations to dispatch.
        gs: Scale forwarded to ``moe_kernel_quantize_input`` as its second
            argument.
        topk_ids: Expert ids, exchanged by the prepare call.
        topk_weights: Router weights, exchanged by the prepare call and then
            viewed back as their original dtype.
        top_k: Forwarded to the prepare call.
        num_experts: Forwarded to the prepare call (passed twice).
        quant_config: Provides ``quant_dtype``, ``per_act_token_quant`` and
            ``block_shape`` for input quantization.
        defer_input_quant: When true, ``x`` is dispatched as-is and no scale
            tensor is produced.

    Returns:
        ``(alltoall_info, topk_ids, topk_weights, x, x_sf)``; ``x_sf`` is
        ``None`` when ``defer_input_quant`` is true.

    Raises:
        AssertionError: If the all-to-all workspace is not available.
    """
    from flashinfer.comm.trtllm_alltoall import MnnvlMoe

    # The call must not live inside an `assert`: under `python -O` the whole
    # assert statement (including the call) would be removed.
    # NOTE(review): the name suggests this call may lazily initialize the
    # workspace - confirm against the manager implementation.
    if not all2all_manager.ensure_alltoall_workspace_initialized():
        raise AssertionError("FlashInfer AllToAll workspace not available")

    ep_rank = all2all_manager.rank
    ep_size = all2all_manager.world_size
    max_num_token = (
        max(global_num_tokens_cpu) if global_num_tokens_cpu is not None else x.shape[0]
    )

    def _all_to_all(tensor):
        # Single place for the identical dispatch call used for x and x_sf.
        return MnnvlMoe.mnnvl_moe_alltoallv(
            tensor,
            alltoall_info,
            all2all_manager.workspace_tensor,
            ep_rank,
            ep_size,
        )

    orig_topk_weights_dtype = topk_weights.dtype
    alltoall_info, topk_ids, topk_weights, _ = (
        MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
            topk_ids,
            topk_weights,
            None,
            all2all_manager.prepare_workspace_tensor,
            max_num_token,
            ep_rank,
            ep_size,
            num_experts,
            num_experts,
            top_k,
        )
    )
    # Reinterpret the returned weights with the dtype they had on entry.
    topk_weights = topk_weights.view(dtype=orig_topk_weights_dtype)

    if not defer_input_quant:
        x, x_sf = moe_kernel_quantize_input(
            x,
            gs,
            quant_config.quant_dtype,
            quant_config.per_act_token_quant,
            quant_config.block_shape,
            # NOTE: swizzling pads the scales to multiple of 128
            # which makes the scales tensor different shape than
            # the hidden states, breaking the A2A kernel. So, we
            # delay the swizzling until after the A2A.
            is_fp4_scale_swizzled=False,
        )
        x = _all_to_all(x)
        x_sf = _all_to_all(x_sf)

        # Swizzle after the A2A if nvfp4.
        if quant_config.quant_dtype == "nvfp4":
            if x_sf.element_size() == 1:
                x_sf = x_sf.view(torch.uint8)
            x_sf = nvfp4_block_scale_interleave(x_sf)
    else:
        # Block-scale path: pass activations through without quantization
        x_sf = None
        x = _all_to_all(x)

    return alltoall_info, topk_ids, topk_weights, x, x_sf
def flashinfer_alltoall_combine(
    all2all_manager: All2AllManagerBase,
    output: torch.Tensor,
    top_k: int,
    token_count: int,
    alltoall_info,
):
    """Combine expert outputs across ranks through the FlashInfer all-to-all.

    Args:
        all2all_manager: Supplies the rank, world size and workspace tensor.
        output: Expert output tensor to combine.
        top_k: Forwarded to the combine call.
        token_count: Forwarded to the combine call.
        alltoall_info: Routing info returned by the matching dispatch.

    Returns:
        The result of ``MnnvlMoe.mnnvl_moe_alltoallv_combine``.

    Raises:
        AssertionError: If the all-to-all workspace is not available.
    """
    from flashinfer.comm.trtllm_alltoall import MnnvlMoe

    # Keep the call outside of `assert` so that it still runs (and is still
    # checked) when Python is started with -O, which strips assert statements.
    if not all2all_manager.ensure_alltoall_workspace_initialized():
        raise AssertionError("FlashInfer AllToAll workspace not available")

    return MnnvlMoe.mnnvl_moe_alltoallv_combine(
        output,
        alltoall_info,
        all2all_manager.workspace_tensor,
        ep_rank=all2all_manager.rank,
        ep_size=all2all_manager.world_size,
        top_k=top_k,
        token_count=token_count,
    )
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
View file @
d76fc11e
...
...
@@ -78,16 +78,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# - skip input activation quantization (kernel applies scaling)
self
.
use_deepseek_fp8_block_scale
=
quant_config
.
is_block_quantized
@
staticmethod
def
expects_unquantized_inputs
(
moe_config
:
mk
.
FusedMoEConfig
,
quant_config
:
FusedMoEQuantConfig
)
->
bool
:
# NVFP4 TP kernels and FP8 block-quantized kernels apply
# input quantization inside FusedMoEPermuteExpertsUnpermute.
return
(
quant_config
.
use_nvfp4_w4a4
and
not
moe_config
.
moe_parallel_config
.
use_all2all_kernels
)
or
(
quant_config
.
use_fp8_w8a8
and
quant_config
.
is_block_quantized
)
@
property
def
expects_unquantized_inputs
(
self
)
->
bool
:
return
self
.
quant_config
.
use_fp8_w8a8
and
self
.
quant_config
.
is_block_quantized
@
staticmethod
def
_supports_current_device
()
->
bool
:
...
...
@@ -144,10 +137,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# FLASHINFER_CUTLASS currently uses its own P/F, which does not
# work with SP. This will be removed in follow up after we get
# rid of the FlashInfer specific P/F function.
return
(
moe_parallel_config
.
dp_size
==
1
or
moe_parallel_config
.
dp_size
==
moe_parallel_config
.
ep_size
)
# TODO: the per-tensor fp8 kernels don't work with MNNVL FI A2As.
return
not
moe_parallel_config
.
is_sequence_parallel
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
...
...
@@ -194,8 +185,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
workspace1
=
(
M
,
K
)
workspace2
=
(
0
,)
# For TP, the quantization is fused with fused_moe call.
output_shape
=
(
M
,
K
*
2
if
self
.
quant_dtype
==
"nvfp4"
and
self
.
use_dp
else
K
)
# For NVFP4, the output is stored in a packed int8 format,
# so the actual hidden dim is 2x the size of K here.
output_shape
=
(
M
,
K
*
2
if
self
.
quant_dtype
==
"nvfp4"
else
K
)
# The workspace is determined by `aq`, since it comes after any
# potential communication op and is involved in the expert computation.
return
(
workspace1
,
workspace2
,
output_shape
)
...
...
vllm/model_executor/layers/fused_moe/fused_batched_moe.py
View file @
d76fc11e
...
...
@@ -533,7 +533,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map
:
torch
.
Tensor
|
None
,
apply_router_weight_on_input
:
bool
,
quant_config
:
FusedMoEQuantConfig
,
defer_input_quant
:
bool
=
False
,
)
->
mk
.
PrepareResultType
:
if
defer_input_quant
:
raise
NotImplementedError
(
f
"
{
self
.
__class__
.
__name__
}
does not support defer_input_quant=True. "
"Please select an MoE kernel that accepts quantized inputs."
)
assert
a1
.
dim
()
==
2
assert
topk_ids
.
dim
()
==
2
assert
topk_ids
.
size
(
0
)
==
a1
.
size
(
0
)
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
d76fc11e
...
...
@@ -597,7 +597,7 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
return
True
return
not
moe_parallel_config
.
use_fi_all2allv_kernels
@
property
def
quant_type_id
(
self
)
->
int
:
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
d76fc11e
...
...
@@ -2465,7 +2465,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
return
True
return
not
moe_parallel_config
.
use_fi_all2allv_kernels
def
supports_chunking
(
self
)
->
bool
:
return
True
...
...
vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
View file @
d76fc11e
...
...
@@ -5,6 +5,7 @@ from abc import abstractmethod
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
...
...
@@ -26,6 +27,19 @@ class FusedMoEMethodBase(QuantizeMethodBase):
super
().
__init__
()
self
.
moe
:
FusedMoEConfig
=
moe
self
.
moe_quant_config
:
FusedMoEQuantConfig
|
None
=
None
self
.
moe_mk
:
mk
.
FusedMoEModularKernel
|
None
=
None
@
property
def
supports_internal_mk
(
self
)
->
bool
:
# NOTE(rob): temporary attribute to indicate support for
# completed migration to the new internal MK interface.
return
self
.
moe_mk
is
not
None
@
property
def
mk_owns_shared_expert
(
self
)
->
bool
:
# NOTE(rob): temporary attribute to indicate support for
# completed migration to the new internal MK interface.
return
self
.
moe_mk
is
not
None
and
self
.
moe_mk
.
shared_experts
is
not
None
@
abstractmethod
def
create_weights
(
...
...
@@ -91,6 +105,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@
property
def
topk_indices_dtype
(
self
)
->
torch
.
dtype
|
None
:
if
self
.
moe_mk
is
not
None
:
return
self
.
moe_mk
.
prepare_finalize
.
topk_indices_dtype
()
return
None
@
property
...
...
vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
View file @
d76fc11e
...
...
@@ -30,11 +30,11 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
):
super
().
__init__
(
old_quant_method
.
moe
)
self
.
moe_quant_config
=
old_quant_method
.
moe_quant_config
self
.
fused_experts
=
experts
self
.
moe_mk
=
experts
self
.
disable_expert_map
=
getattr
(
old_quant_method
,
"disable_expert_map"
,
not
self
.
fused_experts
.
supports_expert_map
(),
not
self
.
moe_mk
.
supports_expert_map
(),
)
self
.
old_quant_method
=
old_quant_method
assert
not
self
.
old_quant_method
.
is_monolithic
...
...
@@ -57,10 +57,6 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
),
)
@
property
def
topk_indices_dtype
(
self
)
->
torch
.
dtype
|
None
:
return
self
.
fused_experts
.
prepare_finalize
.
topk_indices_dtype
()
@
property
def
supports_eplb
(
self
)
->
bool
:
return
self
.
old_quant_method
.
supports_eplb
...
...
@@ -96,7 +92,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
self
.
fused_experts
(
assert
self
.
moe_mk
is
not
None
return
self
.
moe_mk
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
...
...
Prev
1
2
3
4
5
6
7
8
9
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment