Unverified Commit ec965569 authored by Yihua Cheng's avatar Yihua Cheng Committed by GitHub
Browse files

[KV connector][LMCache] Only record the CUDA event when there are requests to store/load (#30814)


Signed-off-by: default avatarApostaC <yihua98@uchicago.edu>
parent 82dc338a
...@@ -262,6 +262,7 @@ class LMCacheMPWorkerAdapter: ...@@ -262,6 +262,7 @@ class LMCacheMPWorkerAdapter:
): ):
keys = [] keys = []
block_ids = [] block_ids = []
for op in ops: for op in ops:
keys.extend(self._block_hashes_to_keys(op.block_hashes)) keys.extend(self._block_hashes_to_keys(op.block_hashes))
block_ids.extend(op.block_ids) block_ids.extend(op.block_ids)
......
...@@ -24,6 +24,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import ( ...@@ -24,6 +24,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
) )
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.request import RequestStatus
from vllm.v1.utils import ConstantList from vllm.v1.utils import ConstantList
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -211,7 +212,7 @@ class LMCacheMPRequestTracker: ...@@ -211,7 +212,7 @@ class LMCacheMPRequestTracker:
""" """
self.num_stored_blocks += num_new_blocks self.num_stored_blocks += num_new_blocks
def update_block_ids( def append_block_ids(
self, self,
new_block_ids: list[int], new_block_ids: list[int],
): ):
...@@ -455,10 +456,6 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -455,10 +456,6 @@ class LMCacheMPConnector(KVConnectorBase_V1):
metadata = self._get_connector_metadata() metadata = self._get_connector_metadata()
assert isinstance(metadata, LMCacheMPConnectorMetadata) assert isinstance(metadata, LMCacheMPConnectorMetadata)
with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
request_ids = [] request_ids = []
ops = [] ops = []
...@@ -468,10 +465,14 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -468,10 +465,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
request_ids.append(meta.request_id) request_ids.append(meta.request_id)
ops.append(meta.op) ops.append(meta.op)
if len(request_ids) > 0: if len(request_ids) == 0:
self.worker_adapter.batched_submit_retrieve_requests( return
request_ids, ops, event
) with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
self.worker_adapter.batched_submit_retrieve_requests(request_ids, ops, event)
def wait_for_layer_load(self, layer_name: str) -> None: def wait_for_layer_load(self, layer_name: str) -> None:
""" """
...@@ -518,10 +519,6 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -518,10 +519,6 @@ class LMCacheMPConnector(KVConnectorBase_V1):
metadata = self._get_connector_metadata() metadata = self._get_connector_metadata()
assert isinstance(metadata, LMCacheMPConnectorMetadata) assert isinstance(metadata, LMCacheMPConnectorMetadata)
with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
request_ids = [] request_ids = []
ops = [] ops = []
for meta in metadata.requests: for meta in metadata.requests:
...@@ -530,8 +527,14 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -530,8 +527,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
request_ids.append(meta.request_id) request_ids.append(meta.request_id)
ops.append(meta.op) ops.append(meta.op)
if len(request_ids) > 0: if len(request_ids) == 0:
self.worker_adapter.batched_submit_store_requests(request_ids, ops, event) return
with torch.cuda.stream(torch.cuda.current_stream()):
event = torch.cuda.Event(interprocess=True)
event.record()
self.worker_adapter.batched_submit_store_requests(request_ids, ops, event)
def get_finished( def get_finished(
self, finished_req_ids: set[str] self, finished_req_ids: set[str]
...@@ -627,6 +630,9 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -627,6 +630,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
into account. into account.
""" """
tracker = self._get_or_create_request_tracker(request) tracker = self._get_or_create_request_tracker(request)
# TODO: support loading KV for preempted requests in the future
if request.status == RequestStatus.PREEMPTED:
return 0, False
self.scheduler_adapter.maybe_submit_lookup_request( self.scheduler_adapter.maybe_submit_lookup_request(
request.request_id, convert_block_hashes_to_bytes(request.block_hashes) request.request_id, convert_block_hashes_to_bytes(request.block_hashes)
...@@ -683,7 +689,7 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -683,7 +689,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
# No matter we need to retrieve or not, we need to update # No matter we need to retrieve or not, we need to update
# the block ids into the tracker # the block ids into the tracker
tracker.update_block_ids(block_ids) tracker.append_block_ids(block_ids)
# Update the state of the tracker # Update the state of the tracker
condition = tracker.needs_retrieve() condition = tracker.needs_retrieve()
...@@ -866,7 +872,8 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -866,7 +872,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
# Update block ids # Update block ids
new_block_ids = reformat_block_ids(cached_reqs.new_block_ids[idx]) new_block_ids = reformat_block_ids(cached_reqs.new_block_ids[idx])
request_tracker.update_block_ids(new_block_ids) if request_id not in cached_reqs.resumed_req_ids:
request_tracker.append_block_ids(new_block_ids)
# Update new scheduled tokens # Update new scheduled tokens
num_new_tokens = cached_reqs.num_computed_tokens[idx] num_new_tokens = cached_reqs.num_computed_tokens[idx]
...@@ -889,6 +896,21 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -889,6 +896,21 @@ class LMCacheMPConnector(KVConnectorBase_V1):
self, request: "Request" self, request: "Request"
) -> LMCacheMPRequestTracker: ) -> LMCacheMPRequestTracker:
request_id = request.request_id request_id = request.request_id
# Remove the old tracker that was created before the preemption
if (
request.status == RequestStatus.PREEMPTED
and request_id in self.request_trackers
):
tracker = self.request_trackers[request_id]
# NOTE: this function may be called multiple times for the same
# request (because get_num_new_matched_tokens may be called
# multiple times), so we only remove the tracker if it is not
# in the "fresh" state, i.e., PREFETCHING
if tracker.state != LMCacheMPRequestState.PREFETCHING:
self.request_trackers.pop(request_id)
if request_id not in self.request_trackers: if request_id not in self.request_trackers:
new_tracker = LMCacheMPRequestTracker(request) new_tracker = LMCacheMPRequestTracker(request)
self.request_trackers[request_id] = new_tracker self.request_trackers[request_id] = new_tracker
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment