Commit a4fc4d7e authored by zhuwenwen's avatar zhuwenwen
Browse files

update cat kernel

parent 9f087f8b
...@@ -226,9 +226,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, ...@@ -226,9 +226,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
if envs.VLLM_USE_OPT_CAT:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
is_vllm_fa = True is_vllm_fa = True
...@@ -1399,6 +1396,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1399,6 +1396,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if envs.VLLM_USE_OPT_CAT: if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024: if k_nope.shape[0] > 1024:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2) dim=2)
else: else:
...@@ -1560,6 +1558,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1560,6 +1558,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if envs.VLLM_USE_OPT_CAT: if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024: if k_nope.shape[0] > 1024:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2) dim=2)
else: else:
......
...@@ -22,8 +22,6 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport ...@@ -22,8 +22,6 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm import envs from vllm import envs
if envs.VLLM_USE_OPT_CAT:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -197,6 +195,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -197,6 +195,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
# if envs.VLLM_USE_OPT_CAT: # if envs.VLLM_USE_OPT_CAT:
# if q_nope.shape[0] < 1024: # if q_nope.shape[0] < 1024:
# from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
# q = concat_helper_decode(q_nope, q_pe, dim=2)\ # q = concat_helper_decode(q_nope, q_pe, dim=2)\
# .unsqueeze(1) # .unsqueeze(1)
# else: # else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment