Commit 15a55773 authored by zhuwenwen
Browse files

add VLLM_USE_OPT_RESHAPE_AND_CACHE, VLLM_USE_FUSE_SILU_AND_MUL and...

add VLLM_USE_OPT_RESHAPE_AND_CACHE, VLLM_USE_FUSE_SILU_AND_MUL and VLLM_USE_TOPK_RENORM for qwen3-30b
parent 1db5839e
...@@ -239,6 +239,9 @@ if TYPE_CHECKING: ...@@ -239,6 +239,9 @@ if TYPE_CHECKING:
VLLM_USE_PP_SYNC: bool = False VLLM_USE_PP_SYNC: bool = False
VLLM_USE_PIECEWISE: bool = False VLLM_USE_PIECEWISE: bool = False
VLLM_USE_V32_ENCODE: bool = False VLLM_USE_V32_ENCODE: bool = False
VLLM_USE_FUSE_SILU_AND_MUL: bool = False
VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False
VLLM_USE_TOPK_RENORM: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1656,7 +1659,21 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1656,7 +1659,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
("true", "1")), ("true", "1")),
# vllm will use encoding_dsv32.py for dpsk-v32 # vllm will use encoding_dsv32.py for dpsk-v32
"VLLM_USE_V32_ENCODE": "VLLM_USE_V32_ENCODE":
lambda: (os.getenv('VLLM_USE_V32_ENCODE', 'False').lower() in lambda: (os.environ.get('VLLM_USE_V32_ENCODE', 'False').lower() in
("true", "1")),
# vLLM will use fused silu+mul kernel (fp16 + qwen3-30b)
"VLLM_USE_FUSE_SILU_AND_MUL":
lambda: (os.environ.get("VLLM_USE_FUSE_SILU_AND_MUL", "False").lower() in
("true", "1")),
# vLLM will use optimized reshape_and_cache kernel when enabled (fp16 + qwen3-30b)
"VLLM_USE_OPT_RESHAPE_AND_CACHE":
lambda:
(os.environ.get("VLLM_USE_OPT_RESHAPE_AND_CACHE", "False").lower() in
("true", "1")),
# vLLM will use optimized topk_softmax + renormalize
"VLLM_USE_TOPK_RENORM":
lambda:
(os.environ.get("VLLM_USE_TOPK_RENORM", "True").lower() in
("true", "1")), ("true", "1")),
} }
......
...@@ -1331,6 +1331,16 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, ...@@ -1331,6 +1331,16 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor, token_expert_indices: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
renormalize: bool) -> tuple[torch.Tensor, ...]: renormalize: bool) -> tuple[torch.Tensor, ...]:
if envs.VLLM_USE_TOPK_RENORM:
from lightop import op as op
op.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
True,
)
else:
ops.topk_softmax( ops.topk_softmax(
topk_weights, topk_weights,
topk_indices, topk_indices,
...@@ -2125,6 +2135,10 @@ def fused_experts_impl( ...@@ -2125,6 +2135,10 @@ def fused_experts_impl(
# Activation function with multiplication # Activation function with multiplication
if activation == "silu": if activation == "silu":
if envs.VLLM_USE_FUSE_SILU_AND_MUL and intermediate_cache1.dtype == intermediate_cache2.dtype == torch.float16:
from lightop import fuse_silu_and_mul
fuse_silu_and_mul(intermediate_cache1.view(-1, N),intermediate_cache2)
else:
torch.ops._C.silu_and_mul(intermediate_cache2, torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N)) intermediate_cache1.view(-1, N))
elif activation == "gelu": elif activation == "gelu":
......
...@@ -201,6 +201,15 @@ def _get_model_architecture( ...@@ -201,6 +201,15 @@ def _get_model_architecture(
else: else:
if not envs.is_set("VLLM_USE_PD_SPLIT"): if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1' os.environ['VLLM_USE_PD_SPLIT'] = '1'
if architectures in [['Qwen3MoeForCausalLM']]:
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"):
os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if architectures in [['DeepseekV32ForCausalLM']]: if architectures in [['DeepseekV32ForCausalLM']]:
if not envs.is_set("VLLM_USE_V32_ENCODE"): if not envs.is_set("VLLM_USE_V32_ENCODE"):
...@@ -219,6 +228,15 @@ def _get_model_architecture( ...@@ -219,6 +228,15 @@ def _get_model_architecture(
else: else:
if not envs.is_set("VLLM_USE_PD_SPLIT"): if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1' os.environ['VLLM_USE_PD_SPLIT'] = '1'
if architectures in [['Qwen3MoeForCausalLM']]:
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"):
os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if architectures in [['DeepseekV32ForCausalLM']]: if architectures in [['DeepseekV32ForCausalLM']]:
if not envs.is_set("VLLM_USE_V32_ENCODE"): if not envs.is_set("VLLM_USE_V32_ENCODE"):
......
...@@ -550,6 +550,20 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -550,6 +550,20 @@ class FlashAttentionImpl(AttentionImpl):
layer._v_scale, layer._v_scale,
) )
else: else:
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE and key.dtype == value.dtype == torch.float16:
from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale
)
else:
from vllm.attention.utils.fa_utils import reshape_and_cache_cuda
reshape_and_cache_cuda( reshape_and_cache_cuda(
key, key,
value, value,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment