Unverified Commit 7227d061 authored by omerpaz95's avatar omerpaz95 Committed by GitHub
Browse files

[Metrics] [KVConnector] Add Offloading Connector metrics (#27942)



Added queries and hits metrics for the Offloading Connector.

Also added timing metrics for store and load operations, which take the
average time it takes to load/store, per-token.

The metrics are available from Prometheus and from the StatLogger.
Signed-off-by: default avataromerpaz95 <omerpaz95@gmail.com>
Co-authored-by: default avatarOmer Paz <Omer.Paz@ibm.com>
parent 14385c80
...@@ -16,6 +16,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole ...@@ -16,6 +16,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
OffloadingConnector, OffloadingConnector,
OffloadingConnectorMetadata, OffloadingConnectorMetadata,
OffloadingConnectorStats,
) )
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.utils.hashing import sha256 from vllm.utils.hashing import sha256
...@@ -86,7 +87,14 @@ class MockOffloadingHandler(OffloadingHandler): ...@@ -86,7 +87,14 @@ class MockOffloadingHandler(OffloadingHandler):
if job_id in self.waiting_jobs: if job_id in self.waiting_jobs:
self.waiting_jobs.remove(job_id) self.waiting_jobs.remove(job_id)
self.completed_jobs.append(job_id) self.completed_jobs.append(job_id)
self.completed_transfers.append((job_id, True)) result = TransferResult(
job_id=job_id,
success=True,
transfer_size=None,
transfer_time=None,
transfer_type=None,
)
self.completed_transfers.append(result)
def wait(self, job_ids: set[int]) -> None: def wait(self, job_ids: set[int]) -> None:
self.flushed_jobs |= job_ids self.flushed_jobs |= job_ids
...@@ -720,3 +728,144 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner): ...@@ -720,3 +728,144 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner):
# second request will use the GPU prefix cache # second request will use the GPU prefix cache
assert transfer_jobs == list(runner.offloading_spec.handler.transfer_specs) assert transfer_jobs == list(runner.offloading_spec.handler.transfer_specs)
class TestOffloadingConnectorStats:
"""Tests for OffloadingConnector stats reconstruction and operations."""
def test_build_kv_connector_stats_with_none(self):
"""Test that build_kv_connector_stats returns empty stats when given None."""
stats = OffloadingConnector.build_kv_connector_stats(data=None)
assert stats is not None
assert isinstance(stats, OffloadingConnectorStats)
assert len(stats.data) == 0
assert stats.is_empty()
def test_build_kv_connector_stats_with_empty_dict(self):
"""Test that build_kv_connector_stats returns empty stats with empty dict."""
stats = OffloadingConnector.build_kv_connector_stats(data={})
assert stats is not None
assert isinstance(stats, OffloadingConnectorStats)
assert len(stats.data) == 0
assert stats.is_empty()
def test_build_kv_connector_stats_reconstructs_offload_stats(self):
"""Test that OffloadingConnector stats are properly reconstructed with
correct data."""
serialized_data = {
"CPU_to_GPU": [
{"op_size": 16, "op_time": 1.0},
{"op_size": 8, "op_time": 0.5},
],
"GPU_to_CPU": [
{"op_size": 1, "op_time": 0.1},
{"op_size": 2, "op_time": 0.2},
],
}
stats = OffloadingConnector.build_kv_connector_stats(data=serialized_data)
offload_connector_stats = stats
assert isinstance(offload_connector_stats, OffloadingConnectorStats)
assert offload_connector_stats.data["CPU_to_GPU"] == [
{"op_size": 16, "op_time": 1.0},
{"op_size": 8, "op_time": 0.5},
]
assert offload_connector_stats.data["GPU_to_CPU"] == [
{"op_size": 1, "op_time": 0.1},
{"op_size": 2, "op_time": 0.2},
]
def test_aggregate_same_connector(self):
"""Test aggregating stats from the same connector type."""
stats1 = OffloadingConnectorStats(
data={
"CPU_to_GPU": [
{"op_size": 16, "op_time": 1.0},
{"op_size": 8, "op_time": 0.5},
],
"GPU_to_CPU": [
{"op_size": 1, "op_time": 0.1},
{"op_size": 2, "op_time": 0.2},
],
}
)
stats2 = OffloadingConnectorStats(
data={
"CPU_to_GPU": [
{"op_size": 3, "op_time": 0.2},
{"op_size": 7, "op_time": 0.9},
],
"GPU_to_CPU": [{"op_size": 16, "op_time": 2}],
}
)
result = stats1.aggregate(stats2)
assert result is stats1 # Should return self
offload_connector_stats = result
assert offload_connector_stats.data["CPU_to_GPU"] == [
{"op_size": 16, "op_time": 1.0},
{"op_size": 8, "op_time": 0.5},
{"op_size": 3, "op_time": 0.2},
{"op_size": 7, "op_time": 0.9},
]
assert offload_connector_stats.data["GPU_to_CPU"] == [
{"op_size": 1, "op_time": 0.1},
{"op_size": 2, "op_time": 0.2},
{"op_size": 16, "op_time": 2},
]
def test_reduce(self):
"""Test that reduce() correctly reduces all nested connector stats."""
stats = OffloadingConnectorStats(
data={
"CPU_to_GPU": [
{"op_size": 16, "op_time": 1.0},
{"op_size": 8, "op_time": 0.5},
{"op_size": 3, "op_time": 0.2},
{"op_size": 7, "op_time": 0.9},
],
"GPU_to_CPU": [
{"op_size": 1, "op_time": 0.1},
{"op_size": 2, "op_time": 0.2},
{"op_size": 16, "op_time": 2},
],
}
)
reduced = stats.reduce()
assert isinstance(reduced, dict)
# Check that the stats were reduced (should have aggregated values)
assert "CPU_to_GPU_total_bytes" in reduced
assert "CPU_to_GPU_total_time" in reduced
assert "GPU_to_CPU_total_bytes" in reduced
assert "GPU_to_CPU_total_time" in reduced
assert reduced["CPU_to_GPU_total_bytes"] == 34
assert reduced["CPU_to_GPU_total_time"] == 2.6
assert reduced["GPU_to_CPU_total_time"] == 2.3
assert reduced["GPU_to_CPU_total_bytes"] == 19
def test_reset(self):
"""Test that reset() resets all nested connector stats."""
offload_connector_stats = OffloadingConnectorStats(
data={
"CPU_to_GPU": [
{"op_size": 3, "op_time": 0.2},
{"op_size": 7, "op_time": 0.9},
],
"GPU_to_CPU": [{"op_size": 16, "op_time": 2}],
}
)
assert not offload_connector_stats.is_empty()
offload_connector_stats.reset()
# After reset, stats should be empty
assert offload_connector_stats.is_empty()
assert len(offload_connector_stats.data) == 0
...@@ -168,15 +168,30 @@ def test_transfer( ...@@ -168,15 +168,30 @@ def test_transfer(
orig_dst_caches = [x.clone() for x in handler.dst_tensors] orig_dst_caches = [x.clone() for x in handler.dst_tensors]
# call transfer function # call transfer function
start_time = time.time()
assert handler.transfer_async(1, (src_spec, dst_spec)) assert handler.transfer_async(1, (src_spec, dst_spec))
assert set({x[0] for x in handler._transfers}) == {1} assert set({x.job_id for x in handler._transfers}) == {1}
# wait for transfer to complete # wait for transfer to complete
end_time = time.time() + 10 end_time = time.time() + 10
while time.time() < end_time: while time.time() < end_time:
finished = handler.get_finished() finished = handler.get_finished()
if finished: if finished:
assert finished == [(1, True)] assert finished[0].job_id == 1
assert finished[0].success
assert (
finished[0].transfer_type == ("GPU", "CPU")
if gpu_to_cpu
else ("CPU", "GPU")
)
assert (
finished[0].transfer_size
== handler.total_block_size_in_bytes
* handler.dst_block_size_factor
* len(dst_blocks)
)
assert finished[0].transfer_time > 0
assert finished[0].transfer_time < (time.time() - start_time)
break break
time.sleep(0.1) time.sleep(0.1)
......
...@@ -124,7 +124,7 @@ class KVConnectorPromMetrics: ...@@ -124,7 +124,7 @@ class KVConnectorPromMetrics:
self._counter_cls = metric_types[Counter] self._counter_cls = metric_types[Counter]
self._histogram_cls = metric_types[Histogram] self._histogram_cls = metric_types[Histogram]
self._labelnames = labelnames self._labelnames = labelnames
self._per_engine_labelvalues = per_engine_labelvalues self.per_engine_labelvalues = per_engine_labelvalues
def make_per_engine(self, metric: PromMetric) -> dict[int, PromMetric]: def make_per_engine(self, metric: PromMetric) -> dict[int, PromMetric]:
""" """
...@@ -134,7 +134,7 @@ class KVConnectorPromMetrics: ...@@ -134,7 +134,7 @@ class KVConnectorPromMetrics:
""" """
return { return {
idx: metric.labels(*labelvalues) idx: metric.labels(*labelvalues)
for idx, labelvalues in self._per_engine_labelvalues.items() for idx, labelvalues in self.per_engine_labelvalues.items()
} }
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
......
...@@ -17,6 +17,12 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import ( ...@@ -17,6 +17,12 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorRole, KVConnectorRole,
) )
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
...@@ -28,7 +34,11 @@ from vllm.v1.kv_offload.abstract import OffloadingManager ...@@ -28,7 +34,11 @@ from vllm.v1.kv_offload.abstract import OffloadingManager
from vllm.v1.kv_offload.factory import OffloadingSpecFactory from vllm.v1.kv_offload.factory import OffloadingSpecFactory
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.worker import OffloadingWorker, TransferSpec from vllm.v1.kv_offload.worker.worker import (
OffloadingWorker,
TransferSpec,
TransferType,
)
from vllm.v1.outputs import KVConnectorOutput from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -37,6 +47,66 @@ ReqId = str ...@@ -37,6 +47,66 @@ ReqId = str
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclass
class OffloadingOperationMetrics:
op_size: int
op_time: float
@dataclass
class OffloadingConnectorStats(KVConnectorStats):
def __post_init__(self):
if not self.data:
# Empty container init, no data is passed in.
self.reset()
def reset(self):
self.data: dict[str, list[OffloadingOperationMetrics]] = {}
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
if not other.is_empty():
for k, v in other.data.items():
if k not in self.data:
self.data[k] = v
else:
accumulator = self.data[k]
assert isinstance(accumulator, list)
accumulator.extend(v)
return self
def reduce(self) -> dict[str, int | float]:
"""
Reduce the observations collected during a time interval to one or
more representative values (eg avg/median/sum of the series).
This is meant to be called by the logger to produce a summary of the
stats for the last time interval.
"""
return_dict: dict[str, int | float] = {}
for transfer_type, ops_list in self.data.items():
assert isinstance(ops_list, list)
total_bytes = 0
total_time = 0
for op in ops_list:
assert isinstance(op, dict)
total_bytes += op["op_size"]
total_time += op["op_time"]
return_dict[f"{transfer_type}_total_bytes"] = total_bytes
return_dict[f"{transfer_type}_total_time"] = total_time
return return_dict
def is_empty(self) -> bool:
return not self.data
def record_transfer(self, num_bytes: int, time: float, transfer_type: TransferType):
src, dst = transfer_type
transfer_type_key = src + "_to_" + dst
op = OffloadingOperationMetrics(num_bytes, time)
if transfer_type_key in self.data:
self.data[transfer_type_key].append(op)
else:
self.data[transfer_type_key] = [op]
@dataclass @dataclass
class OffloadingConnectorMetadata(KVConnectorMetadata): class OffloadingConnectorMetadata(KVConnectorMetadata):
reqs_to_load: dict[ReqId, TransferSpec] reqs_to_load: dict[ReqId, TransferSpec]
...@@ -143,6 +213,33 @@ class OffloadingConnector(KVConnectorBase_V1): ...@@ -143,6 +213,33 @@ class OffloadingConnector(KVConnectorBase_V1):
assert self.connector_scheduler is not None assert self.connector_scheduler is not None
return self.connector_scheduler.take_events() return self.connector_scheduler.take_events()
def get_kv_connector_stats(self) -> KVConnectorStats | None:
if self.connector_worker is None:
return None # We only emit stats from the worker-side
return self.connector_worker.get_kv_connector_stats()
@classmethod
def build_kv_connector_stats(
cls, data: dict[str, Any] | None = None
) -> KVConnectorStats | None:
return (
OffloadingConnectorStats(data=data)
if data is not None
else OffloadingConnectorStats()
)
@classmethod
def build_prom_metrics(
cls,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[object]],
) -> KVConnectorPromMetrics:
return OffloadPromMetrics(
vllm_config, metric_types, labelnames, per_engine_labelvalues
)
class OffloadingConnectorScheduler: class OffloadingConnectorScheduler:
"""Implementation of Scheduler side methods""" """Implementation of Scheduler side methods"""
...@@ -467,7 +564,8 @@ class OffloadingConnectorWorker: ...@@ -467,7 +564,8 @@ class OffloadingConnectorWorker:
self._job_counter = 0 self._job_counter = 0
# job_id -> (req_id, store) self.kv_connector_stats = OffloadingConnectorStats()
# req_id -> (job_id, store)
self._jobs: dict[int, tuple[ReqId, bool]] = {} self._jobs: dict[int, tuple[ReqId, bool]] = {}
# req_id -> active job IDs # req_id -> active job IDs
self._load_job: dict[ReqId, int] = {} self._load_job: dict[ReqId, int] = {}
...@@ -559,10 +657,21 @@ class OffloadingConnectorWorker: ...@@ -559,10 +657,21 @@ class OffloadingConnectorWorker:
""" """
finished_sending = set() finished_sending = set()
finished_recving = set() finished_recving = set()
for job_id, success in self.worker.get_finished(): for transfer_result in self.worker.get_finished():
# we currently do not support job failures # we currently do not support job failures
assert success job_id = transfer_result.job_id
assert transfer_result.success
req_id, store = self._jobs.pop(job_id) req_id, store = self._jobs.pop(job_id)
if (
transfer_result.transfer_time
and transfer_result.transfer_size is not None
and transfer_result.transfer_type is not None
):
self.kv_connector_stats.record_transfer(
num_bytes=transfer_result.transfer_size,
time=transfer_result.transfer_time,
transfer_type=transfer_result.transfer_type,
)
if store: if store:
req_jobs = self._store_jobs[req_id] req_jobs = self._store_jobs[req_id]
req_jobs.remove(job_id) req_jobs.remove(job_id)
...@@ -588,3 +697,104 @@ class OffloadingConnectorWorker: ...@@ -588,3 +697,104 @@ class OffloadingConnectorWorker:
del self._store_jobs[req_id] del self._store_jobs[req_id]
return finished_sending, finished_recving return finished_sending, finished_recving
def get_kv_connector_stats(self) -> KVConnectorStats | None:
"""
Get the KV transfer stats for the connector.
"""
if self.kv_connector_stats.is_empty():
return None
# Clear stats for next iteration
kv_connector_stats = self.kv_connector_stats
self.kv_connector_stats = OffloadingConnectorStats()
return kv_connector_stats
class OffloadPromMetrics(KVConnectorPromMetrics):
def __init__(
self,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[object]],
):
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
# (engine_idx, transfer_tupe) -> (metric with bounded labels)
self.histogram_transfer_size: dict[tuple[int, str], PromMetricT] = {}
self.counter_kv_bytes: dict[tuple[int, str], PromMetricT] = {}
self.counter_kv_transfer_time: dict[tuple[int, str], PromMetricT] = {}
buckets = [ # In bytes
1e6,
5e6,
10e6,
20e6,
40e6,
60e6,
80e6,
100e6,
150e6,
200e6,
]
self._counter_kv_bytes = self._counter_cls(
name="vllm:kv_offload_total_bytes",
documentation="Number of bytes offloaded by KV connector",
labelnames=labelnames + ["transfer_type"],
)
self._counter_kv_transfer_time = self._counter_cls(
name="vllm:kv_offload_total_time",
documentation="Total time measured by all KV offloading operations",
labelnames=labelnames + ["transfer_type"],
)
self._histogram_transfer_size = self._histogram_cls(
name="vllm:kv_offload_size",
documentation="Histogram of KV offload transfer size, in bytes.",
buckets=buckets[:],
labelnames=labelnames + ["transfer_type"],
)
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
"""
Observe transfer statistics from the new data structure.
transfer_stats_data is expected to be a dict where:
- keys are transfer type strings (e.g., "cpu_to_gpu", "gpu_to_cpu")
- values are lists of OffloadingOperationMetrics objects
"""
for transfer_type, ops in transfer_stats_data.items():
# Cache:
if (engine_idx, transfer_type) not in self.histogram_transfer_size:
self.histogram_transfer_size[(engine_idx, transfer_type)] = (
self._histogram_transfer_size.labels(
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
)
)
self.counter_kv_bytes[(engine_idx, transfer_type)] = (
self._counter_kv_bytes.labels(
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
)
)
self.counter_kv_transfer_time[(engine_idx, transfer_type)] = (
self._counter_kv_transfer_time.labels(
*(self.per_engine_labelvalues[engine_idx] + [transfer_type])
)
)
# Process ops:
assert isinstance(ops, list)
for op in ops: # ops is a list of serialized OffloadingOperationMetrics
assert isinstance(op, dict)
# Observe size histogram
self.histogram_transfer_size[(engine_idx, transfer_type)].observe(
op["op_size"]
)
# Increment byte and time counters
self.counter_kv_bytes[(engine_idx, transfer_type)].inc(op["op_size"])
self.counter_kv_transfer_time[(engine_idx, transfer_type)].inc(
op["op_time"]
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import deque from collections import deque
from dataclasses import dataclass
import numpy as np import numpy as np
import torch import torch
...@@ -19,6 +20,15 @@ from vllm.v1.kv_offload.worker.worker import ( ...@@ -19,6 +20,15 @@ from vllm.v1.kv_offload.worker.worker import (
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclass
class Transfer:
job_id: int
stream: torch.cuda.Stream
start_event: torch.Event
end_event: torch.Event
num_bytes: int
def expand_block_ids( def expand_block_ids(
block_ids: np.ndarray, block_ids: np.ndarray,
block_size_factor: int, block_size_factor: int,
...@@ -92,14 +102,15 @@ class SingleDirectionOffloadingHandler(OffloadingHandler): ...@@ -92,14 +102,15 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
tensor.element_size() * tensor.stride(0) * min_block_size_factor tensor.element_size() * tensor.stride(0) * min_block_size_factor
for tensor in src_tensors for tensor in src_tensors
] ]
self.total_block_size_in_bytes = sum(self.block_size_in_bytes)
assert len(src_tensors) > 0 assert len(src_tensors) > 0
self.gpu_to_cpu: bool = self.src_tensors[0].is_cuda self.gpu_to_cpu: bool = self.src_tensors[0].is_cuda
self.transfer_type = ("GPU", "CPU") if self.gpu_to_cpu else ("CPU", "GPU")
# job_id -> event # job_id -> event
self._transfer_events: dict[int, torch.Event] = {} self._transfer_events: dict[int, torch.Event] = {}
# queue of transfers (job_id, stream, event) # queue of transfers (job_id, stream, event)
self._transfers: deque[tuple[int, torch.cuda.Stream, torch.Event]] = deque() self._transfers: deque[Transfer] = deque()
# list of CUDA streams available for re-use # list of CUDA streams available for re-use
self._stream_pool: list[torch.cuda.Stream] = [] self._stream_pool: list[torch.cuda.Stream] = []
# list of CUDA events available for re-use # list of CUDA events available for re-use
...@@ -132,16 +143,27 @@ class SingleDirectionOffloadingHandler(OffloadingHandler): ...@@ -132,16 +143,27 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
src_to_dst_tensor = torch.from_numpy(src_to_dst) src_to_dst_tensor = torch.from_numpy(src_to_dst)
stream = self._stream_pool.pop() if self._stream_pool else torch.cuda.Stream() stream = self._stream_pool.pop() if self._stream_pool else torch.cuda.Stream()
event = self._event_pool.pop() if self._event_pool else torch.Event() start_event = (
self._event_pool.pop()
if self._event_pool
else torch.Event(enable_timing=True)
)
end_event = (
self._event_pool.pop()
if self._event_pool
else torch.Event(enable_timing=True)
)
if self.gpu_to_cpu: if self.gpu_to_cpu:
# wait for model computation to finish before offloading # wait for model computation to finish before offloading
stream.wait_stream(torch.cuda.current_stream()) stream.wait_stream(torch.cuda.current_stream())
if self._transfers: if self._transfers:
_, _, last_event = self._transfers[-1] last_transfer: Transfer = self._transfers[-1]
last_event = last_transfer.end_event
# assure job will start only after the previous one completes # assure job will start only after the previous one completes
stream.wait_event(last_event) stream.wait_event(last_event)
with torch.cuda.stream(stream): with torch.cuda.stream(stream):
start_event.record(stream)
for src_tensor, dst_tensor, block_size_in_bytes in zip( for src_tensor, dst_tensor, block_size_in_bytes in zip(
self.src_tensors, self.src_tensors,
self.dst_tensors, self.dst_tensors,
...@@ -153,22 +175,42 @@ class SingleDirectionOffloadingHandler(OffloadingHandler): ...@@ -153,22 +175,42 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
block_size_in_bytes, block_size_in_bytes,
src_to_dst_tensor, src_to_dst_tensor,
) )
event.record(stream) end_event.record(stream)
self._transfer_events[job_id] = event self._transfer_events[job_id] = end_event
self._transfers.append((job_id, stream, event)) self._transfers.append(
Transfer(
job_id=job_id,
stream=stream,
start_event=start_event,
end_event=end_event,
num_bytes=dst_sub_block_count * self.total_block_size_in_bytes,
)
)
# success # success
return True return True
def get_finished(self) -> list[TransferResult]: def get_finished(self) -> list[TransferResult]:
results: list[TransferResult] = [] results: list[TransferResult] = []
while self._transfers and self._transfers[0][2].query(): while self._transfers and self._transfers[0].end_event.query():
job_id, stream, event = self._transfers.popleft() transfer = self._transfers.popleft()
results.append((job_id, True)) transfer_time = (
self._stream_pool.append(stream) transfer.start_event.elapsed_time(transfer.end_event) * 1e-3
self._event_pool.append(event) ) # elapsed_time is in miliseconds
del self._transfer_events[job_id] result = TransferResult(
job_id=transfer.job_id,
success=True,
transfer_size=transfer.num_bytes,
transfer_time=transfer_time,
transfer_type=self.transfer_type,
)
results.append(result)
self._stream_pool.append(transfer.stream)
self._event_pool.append(transfer.end_event)
self._event_pool.append(transfer.start_event)
del self._transfer_events[transfer.job_id]
return results return results
def wait(self, job_ids: set[int]): def wait(self, job_ids: set[int]):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.kv_offload.abstract import LoadStoreSpec from vllm.v1.kv_offload.abstract import LoadStoreSpec
...@@ -9,12 +10,19 @@ from vllm.v1.kv_offload.abstract import LoadStoreSpec ...@@ -9,12 +10,19 @@ from vllm.v1.kv_offload.abstract import LoadStoreSpec
TransferSpec = tuple[LoadStoreSpec, LoadStoreSpec] TransferSpec = tuple[LoadStoreSpec, LoadStoreSpec]
# transfers are forwarded to workers by (src_medium, dst_medium) # transfers are forwarded to workers by (src_medium, dst_medium)
TransferType = tuple[str, str] TransferType = tuple[str, str]
# transfer result (job_id, success)
TransferResult = tuple[int, bool]
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclass
class TransferResult:
job_id: int
success: bool
transfer_size: int | None = None # Size in bytes
transfer_time: float | None = None
transfer_type: TransferType | None = None
class OffloadingHandler(ABC): class OffloadingHandler(ABC):
""" """
OffloadingHandler class for managing asynchronous KV data transfers OffloadingHandler class for managing asynchronous KV data transfers
...@@ -57,7 +65,6 @@ class OffloadingHandler(ABC): ...@@ -57,7 +65,6 @@ class OffloadingHandler(ABC):
def wait(self, job_ids: set[int]) -> None: def wait(self, job_ids: set[int]) -> None:
""" """
Wait for jobs to finish (blocking). Wait for jobs to finish (blocking).
Args: Args:
job_ids: The set of job IDs to wait for. job_ids: The set of job IDs to wait for.
""" """
...@@ -120,7 +127,6 @@ class OffloadingWorker: ...@@ -120,7 +127,6 @@ class OffloadingWorker:
transfer_type = (src.medium(), dst.medium()) transfer_type = (src.medium(), dst.medium())
handler = self.transfer_type_to_handler.get(transfer_type) handler = self.transfer_type_to_handler.get(transfer_type)
assert handler is not None assert handler is not None
try: try:
success = handler.transfer_async(job_id, spec) success = handler.transfer_async(job_id, spec)
except Exception as e: except Exception as e:
...@@ -137,7 +143,6 @@ class OffloadingWorker: ...@@ -137,7 +143,6 @@ class OffloadingWorker:
logger.warning("Failed to submit %r transfer %d", transfer_type, job_id) logger.warning("Failed to submit %r transfer %d", transfer_type, job_id)
else: else:
logger.debug("Submitted %r transfer %d: %r", transfer_type, job_id, spec) logger.debug("Submitted %r transfer %d: %r", transfer_type, job_id, spec)
return success return success
def get_finished(self) -> list[TransferResult]: def get_finished(self) -> list[TransferResult]:
...@@ -145,7 +150,7 @@ class OffloadingWorker: ...@@ -145,7 +150,7 @@ class OffloadingWorker:
Get transfers finished since last call. Get transfers finished since last call.
Returns: Returns:
A list of (job_id, success) of transfers. A list of TransferResults
""" """
finished = [] finished = []
for handler in self.handlers: for handler in self.handlers:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment