Commit f70d0ec6 authored by zhuwenwen
Browse files

update VLLM_USE_LIGHTOP and VLLM_USE_OPT_CAT

parent f6a1053c
......@@ -201,8 +201,8 @@ if TYPE_CHECKING:
VLLM_USE_FLASH_ATTN_PA: bool = False
VLLM_USE_APEX_RN: bool = False
VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHT_OP: bool = False
VLLM_USE_TRITON_CAT: bool = False
VLLM_USE_LIGHTOP: bool = False
VLLM_USE_OPT_CAT: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
......@@ -1385,14 +1385,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in
("true", "1")),
# vLLM will use lightop for moe_fused_gate and moe_align_block_size
"VLLM_USE_LIGHT_OP":
lambda: (os.environ.get("VLLM_USE_LIGHT_OP", "False").lower() in
# vLLM will use lightop for deepseek-v3
"VLLM_USE_LIGHTOP":
lambda: (os.environ.get("VLLM_USE_LIGHTOP", "False").lower() in
("true", "1")),
# vLLM will use global cache for moe
"VLLM_USE_TRITON_CAT":
lambda: (os.environ.get("VLLM_USE_TRITON_CAT", "True").lower() in
# vLLM will use opt cat for deepseek-v3
"VLLM_USE_OPT_CAT":
lambda: (os.environ.get("VLLM_USE_OPT_CAT", "True").lower() in
("true", "1")),
# vLLM will use the optimized merge_attn_states kernel instead of the Triton one
"VLLM_USE_MERGE_ATTN_STATES_OPT":
......
......@@ -9,7 +9,8 @@ from vllm.triton_utils import triton
from vllm.utils import round_up
import vllm.envs as envs
from lightop import op
if envs.VLLM_USE_LIGHTOP:
from lightop import op as op
def moe_align_block_size(
......@@ -95,7 +96,7 @@ def moe_align_block_size(
dtype=torch.int32,
device=topk_ids.device)
if envs.VLLM_USE_LIGHT_OP:
if envs.VLLM_USE_LIGHTOP:
op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad, None)
else:
......
......@@ -189,6 +189,11 @@ def get_model_architecture(
os.environ['LM_NN'] = '0'
else:
os.environ['LM_NN'] = '1'
if (architectures == ['DeepseekV3ForCausalLM'] or architectures == ['DeepSeekMTPModel']):
if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1'
if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1':
......
......@@ -225,7 +225,9 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
if envs.VLLM_USE_OPT_CAT:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
......@@ -1395,7 +1397,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_TRITON_CAT:
if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024:
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2)
......@@ -1556,7 +1558,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_TRITON_CAT:
if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024:
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2)
......
......@@ -21,7 +21,9 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm import envs
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
if envs.VLLM_USE_OPT_CAT:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
logger = init_logger(__name__)
......@@ -193,12 +195,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
# if envs.VLLM_USE_TRITON_CAT:
# if q_nope.shape[0] <= 1024:
# if envs.VLLM_USE_OPT_CAT:
# if q_nope.shape[0] < 1024:
# q = concat_helper_decode(q_nope, q_pe, dim=2)\
# .unsqueeze(1)
# else:
# q = torch.cat([q_nope, q_pe], dim=-1)\
# .unsqueeze(1) # Add seqlen dim of 1 (decode)
if type(q) is tuple:
q = torch.cat(q, dim=-1)
......
......@@ -5,7 +5,11 @@ from functools import reduce
import pytest
import torch
import math
from lightop import ds_cat
import vllm.envs as envs
if envs.VLLM_USE_LIGHTOP:
from lightop import ds_cat
def test_concat_Acc_prefill(shape_pair, dim):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment