Commit d3824217 authored by zhuwenwen's avatar zhuwenwen
Browse files

configure adaptive optimization operator for qwen3-30b

parents 755d78b4 fe054987
...@@ -187,6 +187,7 @@ if TYPE_CHECKING: ...@@ -187,6 +187,7 @@ if TYPE_CHECKING:
VLLM_USE_CAT_MLA: bool = False VLLM_USE_CAT_MLA: bool = False
VLLM_REJECT_SAMPLE_OPT: bool = False VLLM_REJECT_SAMPLE_OPT: bool = False
VLLM_USE_FUSE_SILU_AND_MUL: bool = False VLLM_USE_FUSE_SILU_AND_MUL: bool = False
VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1103,46 +1104,57 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1103,46 +1104,57 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_ATTN_PA": "VLLM_USE_FLASH_ATTN_PA":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "True").lower() in lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use apex for rmsnorm # vLLM will use apex for rmsnorm
"VLLM_USE_APEX_RN": "VLLM_USE_APEX_RN":
lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use global cache for moe # vLLM will use global cache for moe
"VLLM_USE_GLOBAL_CACHE13": "VLLM_USE_GLOBAL_CACHE13":
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop for deepseek-v3 # vLLM will use lightop for deepseek-v3
"VLLM_USE_LIGHTOP": "VLLM_USE_LIGHTOP":
lambda: (os.environ.get("VLLM_USE_LIGHTOP", "False").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use elementwise not triton_ # vLLM will use elementwise not triton_
"VLLM_USE_OPT_ZEROS": "VLLM_USE_OPT_ZEROS":
lambda: (os.environ.get("VLLM_USE_OPT_ZEROS", "True").lower() in lambda: (os.environ.get("VLLM_USE_OPT_ZEROS", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use opt cat for deepseek-v3 # vLLM will use opt cat for deepseek-v3
"VLLM_USE_OPT_CAT": "VLLM_USE_OPT_CAT":
lambda: (os.environ.get("VLLM_USE_OPT_CAT", "False").lower() in lambda: (os.environ.get("VLLM_USE_OPT_CAT", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use triton moe_sum # vLLM will use triton moe_sum
"VLLM_USE_OPT_MOE_SUM": "VLLM_USE_OPT_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop moe_sum_mul_add
# vLLM will use lightop moe_sum_mul_add for deepseek-v3
"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD": "VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD", "False").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop moe_sum
# vLLM will use lightop moe_sum (qwen3-30b)
"VLLM_USE_LIGHTOP_MOE_SUM": "VLLM_USE_LIGHTOP_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "True").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop moe_align_block_size
# vLLM will use lightop moe_align_block_size (qwen3-30b)
"VLLM_USE_LIGHTOP_MOE_ALIGN": "VLLM_USE_LIGHTOP_MOE_ALIGN":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_ALIGN", "True").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_ALIGN", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use opt merge_attn_states, not triton # vLLM will use opt merge_attn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT": "VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
("true", "1")), ("true", "1")),
# vllm will use rmsquant fused op # vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT": "USE_FUSED_RMS_QUANT":
lambda: bool(int(os.getenv("USE_FUSED_RMS_QUANT", "1"))), lambda: bool(int(os.getenv("USE_FUSED_RMS_QUANT", "1"))),
...@@ -1175,9 +1187,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1175,9 +1187,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_PP_SYNC", "True").lower() in lambda: (os.environ.get("VLLM_USE_PP_SYNC", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop to fuse fill and moe align # vLLM will use lightop to fuse fill and moe align (dpsk-v3 + qwen3-30b)
"VLLM_USE_LIGHTOP_FILL_MOE_ALIGN": "VLLM_USE_LIGHTOP_FILL_MOE_ALIGN":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN", "True").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN", "False").lower() in
("true", "1")), ("true", "1")),
# vllm will use custom-allreduce rmsquant fused op # vllm will use custom-allreduce rmsquant fused op
...@@ -1208,10 +1220,16 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1208,10 +1220,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv('VLLM_REJECT_SAMPLE_OPT', 'False').lower() in lambda: (os.getenv('VLLM_REJECT_SAMPLE_OPT', 'False').lower() in
("true", "1")), ("true", "1")),
# vLLM will use fused silu+mul kernel # vLLM will use fused silu+mul kernel (fp16 + qwen3-30b)
"VLLM_USE_FUSE_SILU_AND_MUL": "VLLM_USE_FUSE_SILU_AND_MUL":
lambda: (os.environ.get("VLLM_USE_FUSE_SILU_AND_MUL", "False").lower() in lambda: (os.environ.get("VLLM_USE_FUSE_SILU_AND_MUL", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use optimized reshape_and_cache kernel when enabled (fp16 + qwen3-30b)
"VLLM_USE_OPT_RESHAPE_AND_CACHE":
lambda:
(os.environ.get("VLLM_USE_OPT_RESHAPE_AND_CACHE", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -1875,7 +1875,7 @@ def fused_experts_impl( ...@@ -1875,7 +1875,7 @@ def fused_experts_impl(
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
if activation == "silu": if activation == "silu":
if envs.VLLM_USE_FUSE_SILU_AND_MUL: if envs.VLLM_USE_FUSE_SILU_AND_MUL and intermediate_cache1.dtype == intermediate_cache2.dtype == "fp16":
from lightop import fuse_silu_and_mul from lightop import fuse_silu_and_mul
fuse_silu_and_mul(intermediate_cache1.view(-1, N),intermediate_cache2) fuse_silu_and_mul(intermediate_cache1.view(-1, N),intermediate_cache2)
else: else:
......
...@@ -251,15 +251,24 @@ def get_model_architecture( ...@@ -251,15 +251,24 @@ def get_model_architecture(
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1' os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"): if not envs.is_set("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"):
# os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1' os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_CAT_MLA"): if not envs.is_set("VLLM_USE_CAT_MLA"):
os.environ['VLLM_USE_CAT_MLA'] = '1' os.environ['VLLM_USE_CAT_MLA'] = '1'
else: else:
if not envs.is_set("VLLM_USE_PD_SPLIT"): if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1' os.environ['VLLM_USE_PD_SPLIT'] = '1'
if architectures in [['Qwen3MoeForCausalLM']]:
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"):
os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"):
os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"): if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1' os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if os.getenv('GEMM_PAD') != '1': if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
...@@ -273,15 +282,29 @@ def get_model_architecture( ...@@ -273,15 +282,29 @@ def get_model_architecture(
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1' os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"): if not envs.is_set("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"):
# os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1' os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_CAT_MLA"): if not envs.is_set("VLLM_USE_CAT_MLA"):
os.environ['VLLM_USE_CAT_MLA'] = '1' os.environ['VLLM_USE_CAT_MLA'] = '1'
else: else:
if not envs.is_set("VLLM_USE_PD_SPLIT"): if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1' os.environ['VLLM_USE_PD_SPLIT'] = '1'
if architectures in [['Qwen3MoeForCausalLM']]:
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"):
os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_FILL_MOE_ALIGN"):
os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"): if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1' os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0'
# AWQ-related configuration # AWQ-related configuration
try: try:
......
...@@ -565,6 +565,16 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -565,6 +565,16 @@ class FlashAttentionImpl(AttentionImpl):
layer._k_scale, layer._k_scale,
layer._v_scale, layer._v_scale,
) )
else:
from lightop import reshape_and_cache_cuda
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE and key.dtype == value.dtype == "fp16":
reshape_and_cache_cuda(
key, value,
key_cache, value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale, layer._v_scale
)
else: else:
reshape_and_cache_cuda( reshape_and_cache_cuda(
key, key,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment