"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "8d0e36b5d4f6e6f35a74bbb1c7974ed1e082242d"
Unverified Commit ef5a2268 authored by Chendi.Xue's avatar Chendi.Xue Committed by GitHub
Browse files

[PD][HeteroArch]Fix accuracy issue with CPU_ATTN as Decoder and Flash_ATTN as prefiller (#38935)


Signed-off-by: default avatarChendi Xue <chendi.xue@intel.com>
parent aec18492
...@@ -523,6 +523,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): ...@@ -523,6 +523,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
kv_cache_layout="HND", kv_cache_layout="HND",
block_size=self.block_size, block_size=self.block_size,
ssm_sizes=(0, 0), ssm_sizes=(0, 0),
attn_backend_name=self.backend_name,
), ),
remote_tp_rank=remote_tp_rank, remote_tp_rank=remote_tp_rank,
remote_tp_size=remote_tp_size, remote_tp_size=remote_tp_size,
...@@ -972,6 +973,7 @@ class TestNixlHandshake: ...@@ -972,6 +973,7 @@ class TestNixlHandshake:
kv_cache_layout=mismatched_layout, kv_cache_layout=mismatched_layout,
block_size=worker.block_size, block_size=worker.block_size,
ssm_sizes=(0, 0), ssm_sizes=(0, 0),
attn_backend_name=worker.backend_name,
) )
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
...@@ -1028,6 +1030,7 @@ class TestNixlHandshake: ...@@ -1028,6 +1030,7 @@ class TestNixlHandshake:
kv_cache_layout="HND", kv_cache_layout="HND",
block_size=worker.block_size, block_size=worker.block_size,
ssm_sizes=(0, 0), ssm_sizes=(0, 0),
attn_backend_name=worker.backend_name,
) )
# We don't check layout for homogeneous TP and MLA for now, as the # We don't check layout for homogeneous TP and MLA for now, as the
...@@ -2347,6 +2350,7 @@ def test_compatibility_hash_validation( ...@@ -2347,6 +2350,7 @@ def test_compatibility_hash_validation(
kv_cache_layout="HND", kv_cache_layout="HND",
block_size=prefill_block_size, block_size=prefill_block_size,
ssm_sizes=(0, 0), ssm_sizes=(0, 0),
attn_backend_name=decode_worker.backend_name,
) )
handshake_payload = NixlHandshakePayload( handshake_payload = NixlHandshakePayload(
compatibility_hash=remote_hash, compatibility_hash=remote_hash,
......
...@@ -173,6 +173,7 @@ class NixlAgentMetadata: ...@@ -173,6 +173,7 @@ class NixlAgentMetadata:
kv_cache_layout: str kv_cache_layout: str
block_size: int block_size: int
ssm_sizes: tuple[int, int] ssm_sizes: tuple[int, int]
attn_backend_name: str
@dataclass @dataclass
...@@ -1116,6 +1117,7 @@ class NixlConnectorWorker: ...@@ -1116,6 +1117,7 @@ class NixlConnectorWorker:
self.num_blocks = kv_cache_config.num_blocks self.num_blocks = kv_cache_config.num_blocks
self.enable_permute_local_kv = False self.enable_permute_local_kv = False
self.enable_heterogeneous_attn_post_process = False
# KV Caches and nixl tracking data. # KV Caches and nixl tracking data.
self.device_type = current_platform.device_type self.device_type = current_platform.device_type
...@@ -1776,6 +1778,7 @@ class NixlConnectorWorker: ...@@ -1776,6 +1778,7 @@ class NixlConnectorWorker:
else self.host_buffer_kv_cache_layout, else self.host_buffer_kv_cache_layout,
block_size=self.block_size, block_size=self.block_size,
ssm_sizes=self._mamba_ssm_size, ssm_sizes=self._mamba_ssm_size,
attn_backend_name=self.backend_name,
) )
# Wrap metadata in payload with hash for defensive decoding # Wrap metadata in payload with hash for defensive decoding
assert self.compat_hash is not None assert self.compat_hash is not None
...@@ -2369,6 +2372,21 @@ class NixlConnectorWorker: ...@@ -2369,6 +2372,21 @@ class NixlConnectorWorker:
"Or enable experimental feature to use HND to NHD support by " "Or enable experimental feature to use HND to NHD support by "
"setting 'enable_permute_local_kv'=True in --kv-transfer-config." "setting 'enable_permute_local_kv'=True in --kv-transfer-config."
) )
# if remote_agent used attn is not same as local,
# hint heterogenuous attn post process
if (
nixl_agent_meta.attn_backend_name != self.backend_name
and self.backend_name in ["CPU_ATTN"]
):
if self._is_hma_required:
raise RuntimeError(
"heterogeneous attn post process is not supported with HMA"
)
logger.info(
"[Experimental] CPU_ATTN backend is used, "
"hint heterogeneous attn post process"
)
self.enable_heterogeneous_attn_post_process = True
# Heterogeneous TP requires head-splitting, which only works with # Heterogeneous TP requires head-splitting, which only works with
# HND layout. MLA and replicated-KV cases don't split on heads. # HND layout. MLA and replicated-KV cases don't split on heads.
...@@ -2542,6 +2560,28 @@ class NixlConnectorWorker: ...@@ -2542,6 +2560,28 @@ class NixlConnectorWorker:
cache, indices, block_size_ratio cache, indices, block_size_ratio
) )
def post_process_device_kv_on_receive_heterogeneous_attn(
self, block_ids: list[int]
):
"""
Post process device kv cache after receiving from remote
for heterogeneous attention.
"""
assert self.enable_heterogeneous_attn_post_process
indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)
for _, cache_or_caches in self.device_kv_caches.items():
blocks_to_update = cache_or_caches.index_select(1, indices)
current_platform.pack_kv_cache(
key=blocks_to_update[0],
value=blocks_to_update[1],
key_cache=cache_or_caches[0],
value_cache=cache_or_caches[1],
block_ids=block_ids,
indices=indices,
)
def get_finished(self) -> tuple[set[str], set[str]]: def get_finished(self) -> tuple[set[str], set[str]]:
""" """
Get requests that are done sending or recving on this specific worker. Get requests that are done sending or recving on this specific worker.
...@@ -2566,6 +2606,7 @@ class NixlConnectorWorker: ...@@ -2566,6 +2606,7 @@ class NixlConnectorWorker:
) )
block_ids_for_blocksize_post_process = defaultdict(list) block_ids_for_blocksize_post_process = defaultdict(list)
block_ids_for_heterogeneous_attn_post_process = list[list[int]]()
for req_id in done_recving: for req_id in done_recving:
# clean up metadata for completed requests # clean up metadata for completed requests
meta = self._recving_metadata.pop(req_id, None) meta = self._recving_metadata.pop(req_id, None)
...@@ -2585,12 +2626,20 @@ class NixlConnectorWorker: ...@@ -2585,12 +2626,20 @@ class NixlConnectorWorker:
block_ids_for_blocksize_post_process[block_size_ratio].append( block_ids_for_blocksize_post_process[block_size_ratio].append(
meta.local_physical_block_ids[0] meta.local_physical_block_ids[0]
) )
# post processing for heterogeneous attention
if self.enable_heterogeneous_attn_post_process:
block_ids_for_heterogeneous_attn_post_process.append(
meta.local_physical_block_ids[0]
)
for ( for (
block_size_ratio, block_size_ratio,
block_ids_list, block_ids_list,
) in block_ids_for_blocksize_post_process.items(): ) in block_ids_for_blocksize_post_process.items():
self.post_process_device_kv_on_receive(block_size_ratio, block_ids_list) self.post_process_device_kv_on_receive(block_size_ratio, block_ids_list)
for block_ids in block_ids_for_heterogeneous_attn_post_process:
self.post_process_device_kv_on_receive_heterogeneous_attn(block_ids)
# Handle timeout to avoid stranding blocks on remote. # Handle timeout to avoid stranding blocks on remote.
now = time.perf_counter() now = time.perf_counter()
while self._reqs_to_send: while self._reqs_to_send:
......
...@@ -520,3 +520,43 @@ class CpuPlatform(Platform): ...@@ -520,3 +520,43 @@ class CpuPlatform(Platform):
import vllm._C # noqa: F401 import vllm._C # noqa: F401
except ImportError as e: except ImportError as e:
logger.warning("Failed to import from vllm._C: %r", e) logger.warning("Failed to import from vllm._C: %r", e)
@classmethod
def pack_kv_cache(
cls,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_ids: list[int],
indices: torch.Tensor,
) -> None:
"""
Rewrite the kv cache shape for the current platform.
"""
# Import lazily: cpu_attn pulls in _custom_ops, which needs a fully
# initialized vllm.platforms (avoid circular import while CpuPlatform loads).
from vllm._custom_ops import cpu_attn_reshape_and_cache
from vllm.v1.attention.backends.cpu_attn import _get_attn_isa
dtype = key.dtype
# For CPU_ATTN, the shape is [N, num_kv_heads, block_size, head_size]
_, _, block_size, head_size = key_cache.shape
key = key.permute(0, 2, 1, 3).flatten(0, 1)
value = value.permute(0, 2, 1, 3).flatten(0, 1)
isa = _get_attn_isa(dtype, block_size, head_size)
block_offsets = torch.arange(block_size, device="cpu", dtype=torch.long)
num_blocks = len(block_ids)
slot_mapping = (
block_offsets.reshape(1, block_size)
+ indices.reshape(num_blocks, 1) * block_size
).flatten()
cpu_attn_reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping,
isa,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment