Commit fd8764b3 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v-0.11.0-pa' into 'v0.11.0-dev'

[PD][Feat]支持fa_pa kvcahe类型模型推理

See merge request dcutoolkit/deeplearing/vllm!317
parents fd8e4a76 2241085d
......@@ -181,25 +181,34 @@ class P2pNcclConnector(KVConnectorBase_V1):
None. The function modifies `layer` in-place.
"""
if (isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values())
or layer.shape[1] == 2): # MLA or FlashInfer
or (not isinstance(layer, tuple))): # MLA or FlashInfer
num_block = kv_cache.shape[0]
self.check_tensors_except_dim(layer, kv_cache, 0)
if len(block_ids) == num_block:
layer[block_ids, ...] = kv_cache
else:
layer[block_ids[:num_block], ...] = kv_cache
logger.warning(
"🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s", len(block_ids),
num_block, request_id)
elif layer.shape[0] == 2: # FlashAttention
#elif layer.shape[0] == 2: # FlashAttention
else:
num_block = kv_cache.shape[1]
self.check_tensors_except_dim(layer, kv_cache, 1)
#self.check_tensors_except_dim(layer, kv_cache, 1)
if len(block_ids) == num_block:
layer[:, block_ids, ...] = kv_cache
#layer[:, block_ids, ...] = kv_cache
k_ = kv_cache[0].permute(0, 2, 1, 3)
v_ = kv_cache[1].permute(0, 2, 3, 1)
layer[0][block_ids, ...] = k_
layer[1][block_ids, ...] = v_
else:
layer[:, block_ids[:num_block], ...] = kv_cache
#layer[:, block_ids[:num_block], ...] = kv_cache
k_ = kv_cache[0].permute(0, 2, 1, 3)
v_ = kv_cache[1].permute(0, 2, 3, 1)
layer[0][block_ids[:num_block], ...] = k_
layer[1][block_ids[:num_block], ...] = v_
logger.warning(
"🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s", len(block_ids),
......@@ -304,11 +313,19 @@ class P2pNcclConnector(KVConnectorBase_V1):
Returns None if the layout is unsupported.
"""
if (isinstance(attn_metadata, MLACommonMetadata)
or layer.shape[1] == 2): # MLA or FlashInfer
or not isinstance(layer, tuple)): # MLA or FlashInfer
return layer[block_ids, ...]
if layer.shape[0] == 2: # FlashAttention
return layer[:, block_ids, ...]
#if layer.shape[0] == 2: # FlashAttention
# return layer[:, block_ids, ...]
else:
k = layer[0] #(num_blocks, num_kv_heads, block_size, head_size)
v = layer[1] #(num_blocks, num_kv_heads, head_size, block_size)
k = k.permute(0,2,1,3)
v = v.permute(0,3,1,2)
kv = torch.stack([k, v], dim=0).contiguous()
return kv[:, block_ids, ...]
return None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment