Commit 55989b60 authored by zhuwenwen
Browse files

[PD][Feat]支持fa_pa kvcache类型模型推理

parent 451af742
......@@ -201,7 +201,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
None. The function modifies `layer` in-place.
"""
if (
isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values() or layer.shape[1] == 2
isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()) or (not isinstance(layer, tuple))
): # MLA or FlashInfer
num_block = kv_cache.shape[0]
self.check_tensors_except_dim(layer, kv_cache, 0)
......@@ -217,13 +217,23 @@ class P2pNcclConnector(KVConnectorBase_V1):
request_id,
)
elif layer.shape[0] == 2: # FlashAttention
# elif layer.shape[0] == 2: # FlashAttention
else:
num_block = kv_cache.shape[1]
self.check_tensors_except_dim(layer, kv_cache, 1)
# self.check_tensors_except_dim(layer, kv_cache, 1)
if len(block_ids) == num_block:
layer[:, block_ids, ...] = kv_cache
# layer[:, block_ids, ...] = kv_cache
#layer[:, block_ids, ...] = kv_cache
k_ = kv_cache[0].permute(0, 2, 1, 3)
v_ = kv_cache[1].permute(0, 2, 3, 1)
layer[0][block_ids, ...] = k_
layer[1][block_ids, ...] = v_
else:
layer[:, block_ids[:num_block], ...] = kv_cache
# layer[:, block_ids[:num_block], ...] = kv_cache
k_ = kv_cache[0].permute(0, 2, 1, 3)
v_ = kv_cache[1].permute(0, 2, 3, 1)
layer[0][block_ids[:num_block], ...] = k_
layer[1][block_ids[:num_block], ...] = v_
logger.warning(
"🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s",
......@@ -336,12 +346,20 @@ class P2pNcclConnector(KVConnectorBase_V1):
Returns None if the layout is unsupported.
"""
if (
isinstance(attn_metadata, MLACommonMetadata) or layer.shape[1] == 2
isinstance(attn_metadata, MLACommonMetadata) or not isinstance(layer, tuple)
): # MLA or FlashInfer
return layer[block_ids, ...]
if layer.shape[0] == 2: # FlashAttention
return layer[:, block_ids, ...]
# if layer.shape[0] == 2: # FlashAttention
# return layer[:, block_ids, ...]
else:
k = layer[0] #(num_blocks, num_kv_heads, block_size, head_size)
v = layer[1] #(num_blocks, num_kv_heads, head_size, block_size)
k = k.permute(0,2,1,3)
v = v.permute(0,3,1,2)
kv = torch.stack([k, v], dim=0).contiguous()
return kv[:, block_ids, ...]
return None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment