Commit d3824217 authored by zhuwenwen's avatar zhuwenwen
Browse files

configure adaptive optimization operator for qwen3-30b

parents 755d78b4 fe054987
......@@ -187,6 +187,7 @@ if TYPE_CHECKING:
VLLM_USE_CAT_MLA: bool = False
VLLM_REJECT_SAMPLE_OPT: bool = False
VLLM_USE_FUSE_SILU_AND_MUL: bool = False
VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1103,46 +1104,57 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_ATTN_PA":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "True").lower() in
("true", "1")),
# vLLM will use apex for rmsnorm
"VLLM_USE_APEX_RN":
lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in
("true", "1")),
# vLLM will use global cache for moe
"VLLM_USE_GLOBAL_CACHE13":
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in
("true", "1")),
# vLLM will use lightop for deepseek-v3
"VLLM_USE_LIGHTOP":
lambda: (os.environ.get("VLLM_USE_LIGHTOP", "False").lower() in
("true", "1")),
# vLLM will use elenmentwise not triton_
"VLLM_USE_OPT_ZEROS":
lambda: (os.environ.get("VLLM_USE_OPT_ZEROS", "True").lower() in
("true", "1")),
# vLLM will use opt cat for deepseek-v3
"VLLM_USE_OPT_CAT":
lambda: (os.environ.get("VLLM_USE_OPT_CAT", "False").lower() in
("true", "1")),
# vLLM will use triton moe_sum
"VLLM_USE_OPT_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in
("true", "1")),
# vLLM will use lightop moe_sum_mul_add
# vLLM will use lightop moe_sum_mul_add for deepseek-v3
"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD", "False").lower() in
("true", "1")),
# vLLM will use lightop moe_sum
# vLLM will use lightop moe_sum (qwen3-30b)
"VLLM_USE_LIGHTOP_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "True").lower() in
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "False").lower() in
("true", "1")),
# vLLM will use lightop moe_align_block_size
# vLLM will use lightop moe_align_block_size (qwen3-30b)
"VLLM_USE_LIGHTOP_MOE_ALIGN":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_ALIGN", "True").lower() in
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_ALIGN", "False").lower() in
("true", "1")),
# vLLM will use opt merge_aatn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
("true", "1")),
# vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT":
lambda: bool(int(os.getenv("USE_FUSED_RMS_QUANT", "1"))),
......@@ -1175,9 +1187,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_PP_SYNC", "True").lower() in
("true", "1")),
# vLLM will use lightop to fuse fill and moe align
# vLLM will use lightop to fuse fill and moe align (dpsk-v3 + qwen3-30b)
"VLLM_USE_LIGHTOP_FILL_MOE_ALIGN":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN", "True").lower() in
lambda: (os.environ.get("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN", "False").lower() in
("true", "1")),
# vllm will use custom-allreduce rmsquant fused op
......@@ -1208,10 +1220,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv('VLLM_REJECT_SAMPLE_OPT', 'False').lower() in
("true", "1")),
# vLLM will use fused silu+mul kernel
# vLLM will use fused silu+mul kernel (fp16 + qwen3-30b)
"VLLM_USE_FUSE_SILU_AND_MUL":
lambda: (os.environ.get("VLLM_USE_FUSE_SILU_AND_MUL", "False").lower() in
("true", "1")),
# vLLM will use optimized reshape_and_cache kernel when enabled (fp16 + qwen3-30b)
"VLLM_USE_OPT_RESHAPE_AND_CACHE":
lambda:
(os.environ.get("VLLM_USE_OPT_RESHAPE_AND_CACHE", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -1875,7 +1875,7 @@ def fused_experts_impl(
use_nn_moe=use_nn_moe)
if activation == "silu":
if envs.VLLM_USE_FUSE_SILU_AND_MUL:
if envs.VLLM_USE_FUSE_SILU_AND_MUL and intermediate_cache1.dtype == intermediate_cache2.dtype == "fp16":
from lightop import fuse_silu_and_mul
fuse_silu_and_mul(intermediate_cache1.view(-1, N),intermediate_cache2)
else:
......
......@@ -251,15 +251,24 @@ def get_model_architecture(
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"):
# os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"):
os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_CAT_MLA"):
os.environ['VLLM_USE_CAT_MLA'] = '1'
else:
if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1'
if architectures in [['Qwen3MoeForCausalLM']]:
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"):
os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"):
os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0'
......@@ -273,15 +282,29 @@ def get_model_architecture(
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"):
# os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"):
os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_CAT_MLA"):
os.environ['VLLM_USE_CAT_MLA'] = '1'
else:
if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1'
if architectures in [['Qwen3MoeForCausalLM']]:
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"):
os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"):
os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0'
# awq相关配置
try:
......
......@@ -565,6 +565,16 @@ class FlashAttentionImpl(AttentionImpl):
layer._k_scale,
layer._v_scale,
)
else:
from lightop import reshape_and_cache_cuda
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE and key.dtype == value.dtype == "fp16":
reshape_and_cache_cuda(
key, value,
key_cache, value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale, layer._v_scale
)
else:
reshape_and_cache_cuda(
key,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment