Commit fd8764b3 authored by zhuwenwen
Browse files

Merge branch 'v-0.11.0-pa' into 'v0.11.0-dev'

[PD][Feat]支持fa_pa kvcache类型模型推理

See merge request dcutoolkit/deeplearing/vllm!317
parents fd8e4a76 2241085d
...@@ -181,25 +181,34 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -181,25 +181,34 @@ class P2pNcclConnector(KVConnectorBase_V1):
None. The function modifies `layer` in-place. None. The function modifies `layer` in-place.
""" """
if (isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()) if (isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values())
or layer.shape[1] == 2): # MLA or FlashInfer or (not isinstance(layer, tuple))): # MLA or FlashInfer
num_block = kv_cache.shape[0] num_block = kv_cache.shape[0]
self.check_tensors_except_dim(layer, kv_cache, 0) self.check_tensors_except_dim(layer, kv_cache, 0)
if len(block_ids) == num_block: if len(block_ids) == num_block:
layer[block_ids, ...] = kv_cache layer[block_ids, ...] = kv_cache
else: else:
layer[block_ids[:num_block], ...] = kv_cache layer[block_ids[:num_block], ...] = kv_cache
logger.warning( logger.warning(
"🚧kv_cache does not match, block_ids:%d, " "🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s", len(block_ids), "num_block:%d, request_id:%s", len(block_ids),
num_block, request_id) num_block, request_id)
#elif layer.shape[0] == 2: # FlashAttention
elif layer.shape[0] == 2: # FlashAttention else:
num_block = kv_cache.shape[1] num_block = kv_cache.shape[1]
self.check_tensors_except_dim(layer, kv_cache, 1) #self.check_tensors_except_dim(layer, kv_cache, 1)
if len(block_ids) == num_block: if len(block_ids) == num_block:
layer[:, block_ids, ...] = kv_cache #layer[:, block_ids, ...] = kv_cache
k_ = kv_cache[0].permute(0, 2, 1, 3)
v_ = kv_cache[1].permute(0, 2, 3, 1)
layer[0][block_ids, ...] = k_
layer[1][block_ids, ...] = v_
else: else:
layer[:, block_ids[:num_block], ...] = kv_cache #layer[:, block_ids[:num_block], ...] = kv_cache
k_ = kv_cache[0].permute(0, 2, 1, 3)
v_ = kv_cache[1].permute(0, 2, 3, 1)
layer[0][block_ids[:num_block], ...] = k_
layer[1][block_ids[:num_block], ...] = v_
logger.warning( logger.warning(
"🚧kv_cache does not match, block_ids:%d, " "🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s", len(block_ids), "num_block:%d, request_id:%s", len(block_ids),
...@@ -304,11 +313,19 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -304,11 +313,19 @@ class P2pNcclConnector(KVConnectorBase_V1):
Returns None if the layout is unsupported. Returns None if the layout is unsupported.
""" """
if (isinstance(attn_metadata, MLACommonMetadata) if (isinstance(attn_metadata, MLACommonMetadata)
or layer.shape[1] == 2): # MLA or FlashInfer or not isinstance(layer, tuple)): # MLA or FlashInfer
return layer[block_ids, ...] return layer[block_ids, ...]
if layer.shape[0] == 2: # FlashAttention #if layer.shape[0] == 2: # FlashAttention
return layer[:, block_ids, ...] # return layer[:, block_ids, ...]
else:
k = layer[0] #(num_blocks, num_kv_heads, block_size, head_size)
v = layer[1] #(num_blocks, num_kv_heads, head_size, block_size)
k = k.permute(0,2,1,3)
v = v.permute(0,3,1,2)
kv = torch.stack([k, v], dim=0).contiguous()
return kv[:, block_ids, ...]
return None return None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment