Commit 7ff04b72 authored by zhuwenwen's avatar zhuwenwen
Browse files

add VLLM_USE_TOPK_RENORM to use optimized topk_softmax + renormalize

parent d3824217
...@@ -188,6 +188,7 @@ if TYPE_CHECKING: ...@@ -188,6 +188,7 @@ if TYPE_CHECKING:
VLLM_REJECT_SAMPLE_OPT: bool = False VLLM_REJECT_SAMPLE_OPT: bool = False
VLLM_USE_FUSE_SILU_AND_MUL: bool = False VLLM_USE_FUSE_SILU_AND_MUL: bool = False
VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False
VLLM_USE_TOPK_RENORM: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1230,6 +1231,13 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1230,6 +1231,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: lambda:
(os.environ.get("VLLM_USE_OPT_RESHAPE_AND_CACHE", "False").lower() in (os.environ.get("VLLM_USE_OPT_RESHAPE_AND_CACHE", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use optimized topk_softmax + renormalize
"VLLM_USE_TOPK_RENORM":
lambda:
(os.environ.get("VLLM_USE_TOPK_RENORM", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -1227,14 +1227,24 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, ...@@ -1227,14 +1227,24 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor, token_expert_indices: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
renormalize: bool) -> tuple[torch.Tensor, ...]: renormalize: bool) -> tuple[torch.Tensor, ...]:
ops.topk_softmax( if envs.VLLM_USE_TOPK_RENORM:
topk_weights, from lightop import op as op
topk_indices, op.topk_softmax(
token_expert_indices, topk_weights,
gating_output, topk_indices,
) token_expert_indices,
if renormalize: gating_output,
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) Is_renormalize = True,
)
else:
ops.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_indices return topk_weights, topk_indices
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment