Commit e0fdf4e8 authored by zhuwenwen's avatar zhuwenwen
Browse files

update cat kernel

parent f331f103
...@@ -217,9 +217,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, ...@@ -217,9 +217,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
if envs.VLLM_USE_OPT_CAT:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
is_vllm_fa = True is_vllm_fa = True
...@@ -932,6 +929,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -932,6 +929,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if envs.VLLM_USE_OPT_CAT: if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024: if k_nope.shape[0] > 1024:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2) dim=2)
else: else:
...@@ -997,6 +995,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -997,6 +995,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if envs.VLLM_USE_OPT_CAT: if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024: if k_nope.shape[0] > 1024:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2) dim=2)
else: else:
......
...@@ -21,8 +21,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec ...@@ -21,8 +21,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from vllm import envs from vllm import envs
if envs.VLLM_USE_OPT_CAT:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -170,6 +168,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -170,6 +168,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
if envs.VLLM_USE_OPT_CAT: if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024: if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
q = concat_helper_decode(q_nope, q_pe, dim=2)\ q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1) .unsqueeze(1)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment