Commit 7ff04b72 authored by zhuwenwen's avatar zhuwenwen
Browse files

add VLLM_USE_TOPK_RENORM to use optimized topk_softmax + renormalize

parent d3824217
......@@ -188,6 +188,7 @@ if TYPE_CHECKING:
VLLM_REJECT_SAMPLE_OPT: bool = False
VLLM_USE_FUSE_SILU_AND_MUL: bool = False
VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False
VLLM_USE_TOPK_RENORM: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1230,6 +1231,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda:
(os.environ.get("VLLM_USE_OPT_RESHAPE_AND_CACHE", "False").lower() in
("true", "1")),
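# If set, vllm_topk_softmax calls lightop's topk_softmax op (invoked with
# Is_renormalize=True) instead of ops.topk_softmax followed by a separate
# renormalization step. Disabled by default.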
"VLLM_USE_TOPK_RENORM":
lambda:
(os.environ.get("VLLM_USE_TOPK_RENORM", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -1227,14 +1227,24 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool) -> tuple[torch.Tensor, ...]:
ops.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
# Pick the top-k softmax implementation.
#
# The lightop call below always passes Is_renormalize=True, so it is only
# taken when the caller actually requested renormalization. Previously the
# env flag alone selected this path, which meant renormalize=False was
# silently ignored whenever VLLM_USE_TOPK_RENORM was enabled.
if envs.VLLM_USE_TOPK_RENORM and renormalize:
    # Imported here so lightop is only needed when the flag is enabled.
    from lightop import op as lightop_ops

    # NOTE(review): presumably fills topk_weights / topk_indices in place,
    # like ops.topk_softmax does -- confirm against the lightop API.
    lightop_ops.topk_softmax(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        Is_renormalize=True,
    )
else:
    # Default path: run the stock kernel, then renormalize separately if
    # the caller asked for it.
    ops.topk_softmax(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
    )
    if renormalize:
        # Divide each row by its sum over the last dimension.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_indices
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment