Commit 1ecb8be9 authored by xuxz
Browse files

[PD]类与函数数据结构优化参数调整 && 支持小模型pd推理

parent 161789cb
......@@ -36,28 +36,20 @@ logger = init_logger(__name__)
class ReqMeta:
# Per-request metadata: request id, token ids, slot mapping, block ids and
# token count, as listed below.
# NOTE(review): this text is a diff listing without +/- markers, so the fields
# below may mix pre- and post-change lines (token_ids/slot_mapping vs.
# block_ids/num_tokens) — confirm against the checked-in class.
# Request Id
request_id: str
# Request tokens
token_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids
slot_mapping: torch.Tensor
# NOTE(review): a defaulted field followed by non-default fields; if the class
# is decorated as a dataclass (decorator not visible here) that ordering is
# rejected when the class is created — TODO confirm.
slot_mapping_device: torch.Tensor = None
# Request block ids
block_ids: torch.Tensor
# Request num tokens
num_tokens: int
# Build a ReqMeta from a request id, its token ids and its block ids.
# block_size is used only to expand block ids into per-token slot indices.
@staticmethod
def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
block_size: int) -> "ReqMeta":
# Number of tokens in the request; used to trim the slot mapping below.
valid_num_tokens = len(token_ids)
token_ids_tensor = torch.tensor(token_ids)
block_ids_tensor = torch.tensor(block_ids)
num_blocks = block_ids_tensor.shape[0]
# Offsets 0 .. block_size-1 inside one block.
block_offsets = torch.arange(0, block_size)
# Broadcast (1, block_size) + (num_blocks, 1) * block_size so that entry
# [i, j] equals block_ids[i] * block_size + j.
slot_mapping = block_offsets.reshape((1, block_size)) + \
block_ids_tensor.reshape((num_blocks, 1)) * block_size
# Flatten in block order and keep only the first valid_num_tokens entries.
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
return ReqMeta(
request_id=request_id,
token_ids=token_ids_tensor,
slot_mapping=slot_mapping,
block_ids=block_ids_tensor,
num_tokens=len(token_ids),
)
......@@ -222,9 +214,9 @@ class DuSwiftConnector(KVConnectorBase_V1):
if attn_metadata is None:
return
def inject_kv_into_layer(
# NOTE(review): two naming schemes appear in this parameter list
# (dst_kv_cache_layer/src_kv_cache/slot_mapping and layer/kv_cache/block_ids).
# This looks like both sides of a diff shown without +/- markers — confirm
# which set the checked-in signature uses.
dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
layer: torch.Tensor,
kv_cache: torch.Tensor,
block_ids: torch.Tensor,
request_id: str,
) -> None:
"""Inject the KV cache into the layer.
......@@ -240,45 +232,39 @@ class DuSwiftConnector(KVConnectorBase_V1):
[num_tokens].
request_id (str): request id for log
"""
# Slot-mapping variant (dst_kv_cache_layer / src_kv_cache): remember the
# incoming shape before the page axes are merged by the reshape calls below.
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
# True when attn_metadata is an MLACommonMetadata, or when every value of
# attn_metadata.values() is one; .values() is evaluated whenever the first
# isinstance test is False.
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()):
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
# Merge the first two axes so that axis 0 is indexed by slot.
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
num_pages * page_size, -1)
# NOTE(review): helper defined outside this view; by its name it presumably
# validates that the two tensors agree on every dimension but dim 0 — confirm.
self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
0)
num_token = src_kv_cache.shape[0]
if len(slot_mapping) == num_token:
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
# Block-id variant (layer / kv_cache): same metadata test, additionally taken
# whenever layer is not a tuple.
if (isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values())
or (not isinstance(layer, tuple))): # MLA or FlashInfer
num_block = kv_cache.shape[0]
self.check_tensors_except_dim(layer, kv_cache, 0)
if len(block_ids) == num_block:
# Axis 0 of layer is indexed by block id; whole blocks are overwritten.
layer[block_ids, ...] = kv_cache
else:
# Length mismatch: the index is cut to the received length and a warning
# is logged.
# NOTE(review): the slice only shortens the index; if the index is shorter
# than the received tensor the two sides still differ in length — confirm
# the intended handling.
dst_kv_cache_layer[slot_mapping[:num_token],
...] = src_kv_cache
logger.warning(
"🚧src_kv_cache does not match, num_slot:%d, "
"num_token:%d, request_id:%s", len(slot_mapping),
num_token, request_id)
layer[block_ids[:num_block], ...] = kv_cache
# The reshape result is not assigned, so this statement leaves
# dst_kv_cache_layer unchanged.
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
logger.warning(
"🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s", len(block_ids),
num_block, request_id)
#elif layer.shape[0] == 2: # FlashAttention
else:
# Slot-mapping variant for a layer whose axis 0 has size 2: the page axes
# are axes 1 and 2, and slots are indexed on axis 1 after the reshape.
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
2, num_pages * page_size, -1)
self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
1)
num_token = src_kv_cache.shape[1]
if len(slot_mapping) == num_token:
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
# Block-id variant for a tuple layer: kv_cache[0] and kv_cache[1] are written
# into layer[0] and layer[1] separately; the block count is on axis 1.
num_block = kv_cache.shape[1]
#self.check_tensors_except_dim(layer, kv_cache, 1)
if len(block_ids) == num_block:
#layer[:, block_ids, ...] = kv_cache
# permute(0, 2, 1, 3) swaps axes 1 and 2; permute(0, 2, 3, 1) moves axis 1
# to the end. The two halves are therefore stored with different layouts.
k_ = kv_cache[0].permute(0, 2, 1, 3)
v_ = kv_cache[1].permute(0, 2, 3, 1)
layer[0][block_ids, ...] = k_
layer[1][block_ids, ...] = v_
else:
dst_kv_cache_layer[:, slot_mapping[:num_token],
...] = src_kv_cache
#layer[:, block_ids[:num_block], ...] = kv_cache
# Same permutations as the matching-length case, with the index cut to
# num_block entries.
k_ = kv_cache[0].permute(0, 2, 1, 3)
v_ = kv_cache[1].permute(0, 2, 3, 1)
layer[0][block_ids[:num_block], ...] = k_
layer[1][block_ids[:num_block], ...] = v_
logger.warning(
"🚧src_kv_cache does not match, num_slot:%d, "
"num_token:%d, request_id:%s", len(slot_mapping),
num_token, request_id)
# As above: the reshape result is discarded.
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
"🚧kv_cache does not match, block_ids:%d, "
"num_block:%d, request_id:%s", len(block_ids),
num_block, request_id)
# Get the metadata
metadata: KVConnectorMetadata = \
......@@ -300,18 +286,17 @@ class DuSwiftConnector(KVConnectorBase_V1):
if kv_cache is None:
continue
kv_cache_layer = kv_cache[ \
forward_context.virtual_engine]
layer = kv_cache[forward_context.virtual_engine]
kv_cache = self.du_swift_engine.recv_tensor(
request.request_id + "#" + layer_name)
if kv_cache is None:
logger.warning("🚧src_kv_cache is None, %s",
request.request_id)
logger.warning("🚧kv_cache is None, %s", request.request_id)
continue
inject_kv_into_layer(kv_cache_layer, kv_cache,
request.slot_mapping, request.request_id)
inject_kv_into_layer(layer, kv_cache,
request.block_ids, request.request_id)
tensor_id = request.request_id + "#" + layer_name
if tensor_id in self.du_swift_engine.recv_store:
tensor = self.du_swift_engine.recv_store.pop(tensor_id, None)
......@@ -359,20 +344,27 @@ class DuSwiftConnector(KVConnectorBase_V1):
def extract_kv_from_layer(
    layer: torch.Tensor,
    block_ids: torch.Tensor,
) -> torch.Tensor:
    """Gather the cache blocks listed in ``block_ids`` from one layer.

    Two layer layouts are handled:

    * ``layer`` is a single tensor (the MLA / FlashInfer case in the
      original comments): axis 0 is indexed by block id and the selected
      blocks are returned as-is.
    * ``layer`` is a ``(k, v)`` tuple. Per the original inline comments,
      ``k`` is (num_blocks, num_kv_heads, block_size, head_size) and ``v``
      is (num_blocks, num_kv_heads, head_size, block_size). Both halves are
      permuted to (blocks, block_size, num_kv_heads, head_size) and stacked
      on a new leading axis of size 2.

    Args:
        layer: The layer's KV cache, a tensor or a ``(k, v)`` tuple.
        block_ids: 1-D integer tensor of block indices to extract.

    Returns:
        ``layer[block_ids, ...]`` for a tensor layer, otherwise a contiguous
        tensor of shape (2, len(block_ids), block_size, num_kv_heads,
        head_size).
    """
    # The cheap tuple test goes first; the boolean is the same as testing the
    # metadata type first.
    if (not isinstance(layer, tuple)
            or isinstance(attn_metadata, MLACommonMetadata)):
        return layer[block_ids, ...]

    # Pick the request's blocks BEFORE permuting and stacking, so only those
    # blocks are copied rather than the whole cache of the layer.
    k = layer[0][block_ids].permute(0, 2, 1, 3)
    v = layer[1][block_ids].permute(0, 3, 1, 2)
    # torch.stack allocates a fresh, contiguous result.
    return torch.stack([k, v], dim=0)
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, DuSwiftConnectorMetadata)
......@@ -380,7 +372,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
for request in connector_metadata.requests:
request_id = request.request_id
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
kv_cache = extract_kv_from_layer(kv_layer, request.block_ids)
pending = False
with self.du_swift_engine.req_status_cv:
if request_id not in self.du_swift_engine.req_status:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment