Commit e9e95d0f authored by zhuwenwen
Browse files

[perf] use optimized topk_softmax + renormalize (lightop)

parent 06e16a27
......@@ -9,6 +9,8 @@ from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
import vllm.envs as envs
from lightop import op as op
def vllm_topk_softmax(
......@@ -18,13 +20,22 @@ def vllm_topk_softmax(
gating_output: torch.Tensor,
renormalize: bool = False,
) -> tuple[torch.Tensor, ...]:
ops.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
)
if envs.VLLM_USE_TOPK_RENORM and renormalize is True:
op.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
)
else:
ops.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
)
return topk_weights, topk_indices
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment