Commit 15a55773 authored by zhuwenwen
Browse files

add VLLM_USE_OPT_RESHAPE_AND_CACHE, VLLM_USE_FUSE_SILU_AND_MUL and...

add VLLM_USE_OPT_RESHAPE_AND_CACHE, VLLM_USE_FUSE_SILU_AND_MUL, and VLLM_USE_TOPK_RENORM for qwen3-30b
parent 1db5839e
...@@ -239,6 +239,9 @@ if TYPE_CHECKING: ...@@ -239,6 +239,9 @@ if TYPE_CHECKING:
VLLM_USE_PP_SYNC: bool = False VLLM_USE_PP_SYNC: bool = False
VLLM_USE_PIECEWISE: bool = False VLLM_USE_PIECEWISE: bool = False
VLLM_USE_V32_ENCODE: bool = False VLLM_USE_V32_ENCODE: bool = False
VLLM_USE_FUSE_SILU_AND_MUL: bool = False
VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False
VLLM_USE_TOPK_RENORM: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1656,8 +1659,22 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1656,8 +1659,22 @@ environment_variables: dict[str, Callable[[], Any]] = {
("true", "1")), ("true", "1")),
# vllm will use encoding_dsv32.py for dpsk-v32 # vllm will use encoding_dsv32.py for dpsk-v32
"VLLM_USE_V32_ENCODE": "VLLM_USE_V32_ENCODE":
lambda: (os.getenv('VLLM_USE_V32_ENCODE', 'False').lower() in lambda: (os.environ.get('VLLM_USE_V32_ENCODE', 'False').lower() in
("true", "1")), ("true", "1")),
# vLLM will use fused silu+mul kernel (fp16 + qwen3-30b)
"VLLM_USE_FUSE_SILU_AND_MUL":
lambda: (os.environ.get("VLLM_USE_FUSE_SILU_AND_MUL", "False").lower() in
("true", "1")),
# vLLM will use optimized reshape_and_cache kernel when enabled (fp16 + qwen3-30b)
"VLLM_USE_OPT_RESHAPE_AND_CACHE":
lambda:
(os.environ.get("VLLM_USE_OPT_RESHAPE_AND_CACHE", "False").lower() in
("true", "1")),
# vLLM will use optimized topk_softmax + renormalize
"VLLM_USE_TOPK_RENORM":
lambda:
(os.environ.get("VLLM_USE_TOPK_RENORM", "True").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -1331,14 +1331,24 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, ...@@ -1331,14 +1331,24 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor, token_expert_indices: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
renormalize: bool) -> tuple[torch.Tensor, ...]: renormalize: bool) -> tuple[torch.Tensor, ...]:
ops.topk_softmax( if envs.VLLM_USE_TOPK_RENORM:
topk_weights, from lightop import op as op
topk_indices, op.topk_softmax(
token_expert_indices, topk_weights,
gating_output, topk_indices,
) token_expert_indices,
if renormalize: gating_output,
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) True,
)
else:
ops.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_indices return topk_weights, topk_indices
...@@ -2125,8 +2135,12 @@ def fused_experts_impl( ...@@ -2125,8 +2135,12 @@ def fused_experts_impl(
# Activation function with multiplication # Activation function with multiplication
if activation == "silu": if activation == "silu":
torch.ops._C.silu_and_mul(intermediate_cache2, if envs.VLLM_USE_FUSE_SILU_AND_MUL and intermediate_cache1.dtype == intermediate_cache2.dtype == torch.float16:
intermediate_cache1.view(-1, N)) from lightop import fuse_silu_and_mul
fuse_silu_and_mul(intermediate_cache1.view(-1, N),intermediate_cache2)
else:
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
elif activation == "gelu": elif activation == "gelu":
torch.ops._C.gelu_and_mul(intermediate_cache2, torch.ops._C.gelu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N)) intermediate_cache1.view(-1, N))
...@@ -2175,10 +2189,10 @@ def fused_experts_impl( ...@@ -2175,10 +2189,10 @@ def fused_experts_impl(
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
if envs.VLLM_USE_LIGHTOP_MOE_SUM: if envs.VLLM_USE_LIGHTOP_MOE_SUM:
from lightop import op as op from lightop import op as op
op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()), op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=None, output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=None,
expert_mask=None, num_local_tokens=None, factor=1.0) expert_mask=None, num_local_tokens=None, factor=1.0)
elif envs.VLLM_USE_OPT_MOE_SUM: elif envs.VLLM_USE_OPT_MOE_SUM:
moe_reduce_dispatch(intermediate_cache3.view(*intermediate_cache3.size()), out_hidden_states[begin_chunk_idx:end_chunk_idx], begin_chunk_idx, end_chunk_idx) moe_reduce_dispatch(intermediate_cache3.view(*intermediate_cache3.size()), out_hidden_states[begin_chunk_idx:end_chunk_idx], begin_chunk_idx, end_chunk_idx)
else: else:
......
...@@ -201,6 +201,15 @@ def _get_model_architecture( ...@@ -201,6 +201,15 @@ def _get_model_architecture(
else: else:
if not envs.is_set("VLLM_USE_PD_SPLIT"): if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1' os.environ['VLLM_USE_PD_SPLIT'] = '1'
if architectures in [['Qwen3MoeForCausalLM']]:
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"):
os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if architectures in [['DeepseekV32ForCausalLM']]: if architectures in [['DeepseekV32ForCausalLM']]:
if not envs.is_set("VLLM_USE_V32_ENCODE"): if not envs.is_set("VLLM_USE_V32_ENCODE"):
...@@ -219,6 +228,15 @@ def _get_model_architecture( ...@@ -219,6 +228,15 @@ def _get_model_architecture(
else: else:
if not envs.is_set("VLLM_USE_PD_SPLIT"): if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1' os.environ['VLLM_USE_PD_SPLIT'] = '1'
if architectures in [['Qwen3MoeForCausalLM']]:
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"):
os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if architectures in [['DeepseekV32ForCausalLM']]: if architectures in [['DeepseekV32ForCausalLM']]:
if not envs.is_set("VLLM_USE_V32_ENCODE"): if not envs.is_set("VLLM_USE_V32_ENCODE"):
......
...@@ -550,16 +550,30 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -550,16 +550,30 @@ class FlashAttentionImpl(AttentionImpl):
layer._v_scale, layer._v_scale,
) )
else: else:
reshape_and_cache_cuda( if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE and key.dtype == value.dtype == torch.float16:
key, from lightop import reshape_and_cache_cuda
value, reshape_and_cache_cuda(
key_cache, key,
value_cache, value,
attn_metadata.slot_mapping, key_cache,
self.kv_cache_dtype, value_cache,
layer._k_scale, attn_metadata.slot_mapping,
layer._v_scale, self.kv_cache_dtype,
) layer._k_scale,
layer._v_scale
)
else:
from vllm.attention.utils.fa_utils import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"): if self.kv_cache_dtype.startswith("fp8"):
# queries are quantized in the attention layer # queries are quantized in the attention layer
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment