Commit 55989b60 authored by zhuwenwen
Browse files

[PD][Feat]支持fa_pa kvcache类型模型推理

parent 451af742
...@@ -201,7 +201,7 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -201,7 +201,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
None. The function modifies `layer` in-place. None. The function modifies `layer` in-place.
""" """
if ( if (
isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values() or layer.shape[1] == 2 isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()) or (not isinstance(layer, tuple))
): # MLA or FlashInfer ): # MLA or FlashInfer
num_block = kv_cache.shape[0] num_block = kv_cache.shape[0]
self.check_tensors_except_dim(layer, kv_cache, 0) self.check_tensors_except_dim(layer, kv_cache, 0)
...@@ -217,13 +217,23 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -217,13 +217,23 @@ class P2pNcclConnector(KVConnectorBase_V1):
request_id, request_id,
) )
elif layer.shape[0] == 2: # FlashAttention # elif layer.shape[0] == 2: # FlashAttention
else:
num_block = kv_cache.shape[1] num_block = kv_cache.shape[1]
self.check_tensors_except_dim(layer, kv_cache, 1) # self.check_tensors_except_dim(layer, kv_cache, 1)
if len(block_ids) == num_block: if len(block_ids) == num_block:
layer[:, block_ids, ...] = kv_cache # layer[:, block_ids, ...] = kv_cache
#layer[:, block_ids, ...] = kv_cache
k_ = kv_cache[0].permute(0, 2, 1, 3)
v_ = kv_cache[1].permute(0, 2, 3, 1)
layer[0][block_ids, ...] = k_
layer[1][block_ids, ...] = v_
else: else:
layer[:, block_ids[:num_block], ...] = kv_cache # layer[:, block_ids[:num_block], ...] = kv_cache
k_ = kv_cache[0].permute(0, 2, 1, 3)
v_ = kv_cache[1].permute(0, 2, 3, 1)
layer[0][block_ids[:num_block], ...] = k_
layer[1][block_ids[:num_block], ...] = v_
logger.warning( logger.warning(
"🚧kv_cache does not match, block_ids:%d, " "🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s", "num_block:%d, request_id:%s",
...@@ -336,12 +346,20 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -336,12 +346,20 @@ class P2pNcclConnector(KVConnectorBase_V1):
Returns None if the layout is unsupported. Returns None if the layout is unsupported.
""" """
if ( if (
isinstance(attn_metadata, MLACommonMetadata) or layer.shape[1] == 2 isinstance(attn_metadata, MLACommonMetadata) or not isinstance(layer, tuple)
): # MLA or FlashInfer ): # MLA or FlashInfer
return layer[block_ids, ...] return layer[block_ids, ...]
if layer.shape[0] == 2: # FlashAttention # if layer.shape[0] == 2: # FlashAttention
return layer[:, block_ids, ...] # return layer[:, block_ids, ...]
else:
k = layer[0] #(num_blocks, num_kv_heads, block_size, head_size)
v = layer[1] #(num_blocks, num_kv_heads, head_size, block_size)
k = k.permute(0,2,1,3)
v = v.permute(0,3,1,2)
kv = torch.stack([k, v], dim=0).contiguous()
return kv[:, block_ids, ...]
return None return None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment