Commit e0fdf4e8 authored by zhuwenwen's avatar zhuwenwen
Browse files

update cat kernel

parent f331f103
......@@ -217,9 +217,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if envs.VLLM_USE_OPT_CAT:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
is_vllm_fa = True
......@@ -932,6 +929,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2)
else:
......@@ -997,6 +995,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2)
else:
......
......@@ -21,8 +21,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm import envs
if envs.VLLM_USE_OPT_CAT:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
logger = init_logger(__name__)
......@@ -170,6 +168,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment