Commit b98167fc authored by xiabo's avatar xiabo Committed by zhangzbb
Browse files

同步官方v0.15.1的kvcache的处理方式。可以参照官方发pr:https://github.com/vllm-project/vllm/pull/23536/changes

parent ce47a56e
...@@ -34,28 +34,20 @@ logger = init_logger(__name__) ...@@ -34,28 +34,20 @@ logger = init_logger(__name__)
class ReqMeta: class ReqMeta:
# Request Id # Request Id
request_id: str request_id: str
# Request tokens # Request block ids
token_ids: torch.Tensor block_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids # Request num tokens
slot_mapping: torch.Tensor num_tokens: int
slot_mapping_device: torch.Tensor = None
@staticmethod @staticmethod
def make_meta(request_id: str, token_ids: list[int], block_ids: list[int], def make_meta(
block_size: int) -> "ReqMeta": request_id: str, token_ids: list[int], block_ids: list[int], block_size: int
valid_num_tokens = len(token_ids) ) -> "ReqMeta":
token_ids_tensor = torch.tensor(token_ids)
block_ids_tensor = torch.tensor(block_ids) block_ids_tensor = torch.tensor(block_ids)
num_blocks = block_ids_tensor.shape[0]
block_offsets = torch.arange(0, block_size)
slot_mapping = block_offsets.reshape((1, block_size)) + \
block_ids_tensor.reshape((num_blocks, 1)) * block_size
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
return ReqMeta( return ReqMeta(
request_id=request_id, request_id=request_id,
token_ids=token_ids_tensor, block_ids=block_ids_tensor,
slot_mapping=slot_mapping, num_tokens=len(token_ids),
) )
...@@ -74,7 +66,8 @@ class DuSwiftConnectorMetadata(KVConnectorMetadata): ...@@ -74,7 +66,8 @@ class DuSwiftConnectorMetadata(KVConnectorMetadata):
block_size: int, block_size: int,
) -> None: ) -> None:
self.requests.append( self.requests.append(
ReqMeta.make_meta(request_id, token_ids, block_ids, block_size)) ReqMeta.make_meta(request_id, token_ids, block_ids, block_size)
)
class DuSwiftConnector(KVConnectorBase_V1): class DuSwiftConnector(KVConnectorBase_V1):
...@@ -190,63 +183,62 @@ class DuSwiftConnector(KVConnectorBase_V1): ...@@ -190,63 +183,62 @@ class DuSwiftConnector(KVConnectorBase_V1):
if attn_metadata is None: if attn_metadata is None:
return return
def inject_kv_into_layer( def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor, layer: torch.Tensor,
src_kv_cache: torch.Tensor, kv_cache: torch.Tensor,
slot_mapping: torch.Tensor, block_ids: torch.Tensor,
request_id: str, request_id: str,
) -> None: ) -> None:
"""Inject the KV cache into the layer. """
Inject KV cache data into a given attention layer tensor.
This function updates `layer` in-place with values from `kv_cache`,
handling different backend layouts:
- MLA (Multi-Linear Attention) or FlashInfer: KV tensors are
indexed along the first dimension.
- FlashAttention: KV tensors are indexed along the second
dimension.
If the number of provided block IDs does not match the number of KV
blocks, only the overlapping portion is updated, and a warning is
logged.
Args: Args:
dst_kv_cache_layer (torch.Tensor): the destination KV cache layer (torch.Tensor): The attention layer KV tensor to update.
layer. In shape [2, num_pages, page_size, xxx] if not kv_cache (torch.Tensor): The KV cache tensor to inject.
using MLA, [num_pages, page_size, xxx] otherwise. block_ids (torch.Tensor): Indices of the blocks to update.
src_kv_cache (torch.Tensor): the source KV cache. In shape request_id (str): Request identifier used for logging.
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
otherwise. Returns:
slot_mapping (torch.Tensor): the slot mapping. In shape None. The function modifies `layer` in-place.
[num_tokens].
request_id (str): request id for log
""" """
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()) or layer.ndim == 3:
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()) or dst_kv_cache_layer.ndim == 3: num_block = kv_cache.shape[0]
num_pages = dst_kv_cache_layer_shape[0] self.check_tensors_except_dim(layer, kv_cache, 0)
page_size = dst_kv_cache_layer_shape[1] if len(block_ids) == num_block:
dst_kv_cache_layer = dst_kv_cache_layer.reshape( layer[block_ids, ...] = kv_cache
num_pages * page_size, -1)
self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
0)
num_token = src_kv_cache.shape[0]
if len(slot_mapping) == num_token:
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
else: else:
dst_kv_cache_layer[slot_mapping[:num_token], layer[block_ids[:num_block], ...] = kv_cache
...] = src_kv_cache
logger.warning( logger.warning(
"🚧src_kv_cache does not match, num_slot:%d, " "🚧kv_cache does not match, block_ids:%d, "
"num_token:%d, request_id:%s", len(slot_mapping), "num_block:%d, request_id:%s",
num_token, request_id) len(block_ids),
num_block,
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) request_id,
)
else: else:
num_pages = dst_kv_cache_layer_shape[1] num_block = kv_cache.shape[1]
page_size = dst_kv_cache_layer_shape[2] self.check_tensors_except_dim(layer, kv_cache, 1)
dst_kv_cache_layer = dst_kv_cache_layer.reshape( if len(block_ids) == num_block:
2, num_pages * page_size, -1) layer[:, block_ids, ...] = kv_cache
self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
1)
num_token = src_kv_cache.shape[1]
if len(slot_mapping) == num_token:
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
else: else:
dst_kv_cache_layer[:, slot_mapping[:num_token], layer[:, block_ids[:num_block], ...] = kv_cache
...] = src_kv_cache
logger.warning( logger.warning(
"🚧src_kv_cache does not match, num_slot:%d, " "🚧kv_cache does not match, block_ids:%d, "
"num_token:%d, request_id:%s", len(slot_mapping), "num_block:%d, request_id:%s",
num_token, request_id) len(block_ids),
num_block,
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) request_id,
)
# Get the metadata # Get the metadata
metadata: KVConnectorMetadata = \ metadata: KVConnectorMetadata = \
...@@ -280,7 +272,7 @@ class DuSwiftConnector(KVConnectorBase_V1): ...@@ -280,7 +272,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
request.request_id) request.request_id)
continue continue
inject_kv_into_layer(kv_cache_layer, kv_cache, inject_kv_into_layer(kv_cache_layer, kv_cache,
request.slot_mapping, request.request_id) request.block_ids, request.request_id)
tensor_id = request.request_id + "#" + layer_name tensor_id = request.request_id + "#" + layer_name
if tensor_id in self.du_swift_engine.recv_store: if tensor_id in self.du_swift_engine.recv_store:
tensor = self.du_swift_engine.recv_store.pop(tensor_id, None) tensor = self.du_swift_engine.recv_store.pop(tensor_id, None)
...@@ -383,20 +375,28 @@ class DuSwiftConnector(KVConnectorBase_V1): ...@@ -383,20 +375,28 @@ class DuSwiftConnector(KVConnectorBase_V1):
def extract_kv_from_layer( def extract_kv_from_layer(
layer: torch.Tensor, layer: torch.Tensor,
slot_mapping: torch.Tensor, block_ids: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
"""Extract the KV cache from the layer. """
Extract KV cache slices from a given attention layer tensor.
This function handles multiple backend layouts:
- MLA (Multi-Linear Attention) or FlashInfer: KV tensors are
indexed along the first dimension.
- FlashAttention: KV tensors are indexed along the second
dimension.
Assume the shape of the layer is (2, num_pages, page_size, xxx) Args:
if MLA is not used, and (num_pages, page_size, xxx) otherwise. layer (torch.Tensor): The KV cache from the attention layer.
block_ids (torch.Tensor): Indices of blocks to extract.
Returns:
torch.Tensor: A tensor containing the extracted KV slices.
Returns None if the layout is unsupported.
""" """
if isinstance(attn_metadata, MLACommonMetadata) or layer.ndim == 3: if isinstance(attn_metadata, MLACommonMetadata) or layer.ndim == 3:
num_pages, page_size = layer.shape[0], layer.shape[1] return layer[block_ids, ...]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, return layer[:, block_ids, ...]
...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
...]
connector_metadata = self._get_connector_metadata() connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, DuSwiftConnectorMetadata) assert isinstance(connector_metadata, DuSwiftConnectorMetadata)
...@@ -443,7 +443,7 @@ class DuSwiftConnector(KVConnectorBase_V1): ...@@ -443,7 +443,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
p_ip, p_port = self.parse_request_id(request_id, False) p_ip, p_port = self.parse_request_id(request_id, False)
remote_address = ip + ":" + str(port + self._rank) remote_address = ip + ":" + str(port + self._rank)
# pd_pair_id = p_ip + ":" + p_port + "_" + ip + ":" + port # pd_pair_id = p_ip + ":" + p_port + "_" + ip + ":" + port
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) kv_cache = extract_kv_from_layer(kv_layer, request.block_ids)
pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size
) % self.parallel_config.pipeline_parallel_size ) % self.parallel_config.pipeline_parallel_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment