Commit f70d0ec6 authored by zhuwenwen
Browse files

update VLLM_USE_LIGHTOP and VLLM_USE_OPT_CAT

parent f6a1053c
...@@ -201,8 +201,8 @@ if TYPE_CHECKING: ...@@ -201,8 +201,8 @@ if TYPE_CHECKING:
VLLM_USE_FLASH_ATTN_PA: bool = False VLLM_USE_FLASH_ATTN_PA: bool = False
VLLM_USE_APEX_RN: bool = False VLLM_USE_APEX_RN: bool = False
VLLM_USE_GLOBAL_CACHE13: bool = False VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHT_OP: bool = False VLLM_USE_LIGHTOP: bool = False
VLLM_USE_TRITON_CAT: bool = False VLLM_USE_OPT_CAT: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False USE_FUSED_SILU_MUL_QUANT: bool = False
...@@ -1385,14 +1385,14 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1385,14 +1385,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop for moe_fused_gate and moe_align_block_size # vLLM will use lightop for deepseek-v3
"VLLM_USE_LIGHT_OP": "VLLM_USE_LIGHTOP":
lambda: (os.environ.get("VLLM_USE_LIGHT_OP", "False").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use global cache for moe # vLLM will use opt cat for deepseek-v3
"VLLM_USE_TRITON_CAT": "VLLM_USE_OPT_CAT":
lambda: (os.environ.get("VLLM_USE_TRITON_CAT", "True").lower() in lambda: (os.environ.get("VLLM_USE_OPT_CAT", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use opt merge_attn_states, not triton # vLLM will use opt merge_attn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT": "VLLM_USE_MERGE_ATTN_STATES_OPT":
......
...@@ -9,7 +9,8 @@ from vllm.triton_utils import triton ...@@ -9,7 +9,8 @@ from vllm.triton_utils import triton
from vllm.utils import round_up from vllm.utils import round_up
import vllm.envs as envs import vllm.envs as envs
from lightop import op if envs.VLLM_USE_LIGHTOP:
from lightop import op as op
def moe_align_block_size( def moe_align_block_size(
...@@ -95,7 +96,7 @@ def moe_align_block_size( ...@@ -95,7 +96,7 @@ def moe_align_block_size(
dtype=torch.int32, dtype=torch.int32,
device=topk_ids.device) device=topk_ids.device)
if envs.VLLM_USE_LIGHT_OP: if envs.VLLM_USE_LIGHTOP:
op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad, None) expert_ids, num_tokens_post_pad, None)
else: else:
......
...@@ -189,6 +189,11 @@ def get_model_architecture( ...@@ -189,6 +189,11 @@ def get_model_architecture(
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
else: else:
os.environ['LM_NN'] = '1' os.environ['LM_NN'] = '1'
if (architectures == ['DeepseekV3ForCausalLM'] or architectures == ['DeepSeekMTPModel']):
if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1'
if os.getenv('GEMM_PAD') != '1': if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1': if os.getenv('FA_PAD') != '1':
......
...@@ -225,7 +225,9 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, ...@@ -225,7 +225,9 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
split_decodes_and_prefills) split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
if envs.VLLM_USE_OPT_CAT:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
...@@ -1395,7 +1397,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1395,7 +1397,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\ k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_TRITON_CAT: if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024: if k_nope.shape[0] > 1024:
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2) dim=2)
...@@ -1556,7 +1558,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1556,7 +1558,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\ k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_TRITON_CAT: if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024: if k_nope.shape[0] > 1024:
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2) dim=2)
......
...@@ -21,7 +21,9 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend, ...@@ -21,7 +21,9 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm import envs from vllm import envs
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
if envs.VLLM_USE_OPT_CAT:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -193,12 +195,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -193,12 +195,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert kv_c_and_k_pe_cache.numel() > 0 assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
# if envs.VLLM_USE_TRITON_CAT: # if envs.VLLM_USE_OPT_CAT:
# if q_nope.shape[0] <= 1024: # if q_nope.shape[0] < 1024:
# q = concat_helper_decode(q_nope, q_pe, dim=2)\ # q = concat_helper_decode(q_nope, q_pe, dim=2)\
# .unsqueeze(1) # .unsqueeze(1)
# else: # else:
# q = torch.cat([q_nope, q_pe], dim=-1)\ # q = torch.cat([q_nope, q_pe], dim=-1)\
# .unsqueeze(1) # Add seqlen dim of 1 (decode)
if type(q) is tuple: if type(q) is tuple:
q = torch.cat(q, dim=-1) q = torch.cat(q, dim=-1)
......
...@@ -5,7 +5,11 @@ from functools import reduce ...@@ -5,7 +5,11 @@ from functools import reduce
import pytest import pytest
import torch import torch
import math import math
from lightop import ds_cat import vllm.envs as envs
if envs.VLLM_USE_LIGHTOP:
from lightop import ds_cat
def test_concat_Acc_prefill(shape_pair, dim): def test_concat_Acc_prefill(shape_pair, dim):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment