Commit 15a55773 authored by zhuwenwen's avatar zhuwenwen
Browse files

add VLLM_USE_OPT_RESHAPE_AND_CACHE、VLLM_USE_FUSE_SILU_AND_MUL and...

add VLLM_USE_OPT_RESHAPE_AND_CACHE、VLLM_USE_FUSE_SILU_AND_MUL and VLLM_USE_TOPK_RENORM for qwen3-30b
parent 1db5839e
......@@ -239,6 +239,9 @@ if TYPE_CHECKING:
VLLM_USE_PP_SYNC: bool = False
VLLM_USE_PIECEWISE: bool = False
VLLM_USE_V32_ENCODE: bool = False
VLLM_USE_FUSE_SILU_AND_MUL: bool = False
VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False
VLLM_USE_TOPK_RENORM: bool = False
def get_default_cache_root():
......@@ -1656,8 +1659,22 @@ environment_variables: dict[str, Callable[[], Any]] = {
("true", "1")),
# vllm will use encoding_dsv32.py for dpsk-v32
"VLLM_USE_V32_ENCODE":
lambda: (os.getenv('VLLM_USE_V32_ENCODE', 'False').lower() in
lambda: (os.environ.get('VLLM_USE_V32_ENCODE', 'False').lower() in
("true", "1")),
# vLLM will use fused silu+mul kernel (fp16 + qwen3-30b)
"VLLM_USE_FUSE_SILU_AND_MUL":
lambda: (os.environ.get("VLLM_USE_FUSE_SILU_AND_MUL", "False").lower() in
("true", "1")),
# vLLM will use optimized reshape_and_cache kernel when enabled (fp16 + qwen3-30b)
"VLLM_USE_OPT_RESHAPE_AND_CACHE":
lambda:
(os.environ.get("VLLM_USE_OPT_RESHAPE_AND_CACHE", "False").lower() in
("true", "1")),
# vLLM will use optimized topk_softmax + renormalize
"VLLM_USE_TOPK_RENORM":
lambda:
(os.environ.get("VLLM_USE_TOPK_RENORM", "True").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -1331,14 +1331,24 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool) -> tuple[torch.Tensor, ...]:
ops.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
if envs.VLLM_USE_TOPK_RENORM:
from lightop import op as op
op.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
True,
)
else:
ops.topk_softmax(
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_indices
......@@ -2125,8 +2135,12 @@ def fused_experts_impl(
# Activation function with multiplication
if activation == "silu":
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
if envs.VLLM_USE_FUSE_SILU_AND_MUL and intermediate_cache1.dtype == intermediate_cache2.dtype == torch.float16:
from lightop import fuse_silu_and_mul
fuse_silu_and_mul(intermediate_cache1.view(-1, N),intermediate_cache2)
else:
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
elif activation == "gelu":
torch.ops._C.gelu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
......@@ -2175,10 +2189,10 @@ def fused_experts_impl(
use_nn_moe=use_nn_moe)
if envs.VLLM_USE_LIGHTOP_MOE_SUM:
from lightop import op as op
op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=None,
expert_mask=None, num_local_tokens=None, factor=1.0)
from lightop import op as op
op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=None,
expert_mask=None, num_local_tokens=None, factor=1.0)
elif envs.VLLM_USE_OPT_MOE_SUM:
moe_reduce_dispatch(intermediate_cache3.view(*intermediate_cache3.size()), out_hidden_states[begin_chunk_idx:end_chunk_idx], begin_chunk_idx, end_chunk_idx)
else:
......
......@@ -201,6 +201,15 @@ def _get_model_architecture(
else:
if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1'
if architectures in [['Qwen3MoeForCausalLM']]:
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"):
os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if architectures in [['DeepseekV32ForCausalLM']]:
if not envs.is_set("VLLM_USE_V32_ENCODE"):
......@@ -219,6 +228,15 @@ def _get_model_architecture(
else:
if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1'
if architectures in [['Qwen3MoeForCausalLM']]:
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"):
os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if architectures in [['DeepseekV32ForCausalLM']]:
if not envs.is_set("VLLM_USE_V32_ENCODE"):
......
......@@ -550,16 +550,30 @@ class FlashAttentionImpl(AttentionImpl):
layer._v_scale,
)
else:
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE and key.dtype == value.dtype == torch.float16:
from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale
)
else:
from vllm.attention.utils.fa_utils import reshape_and_cache_cuda
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
# queries are quantized in the attention layer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment