"vllm/vscode:/vscode.git/clone" did not exist on "15436806912d7ad9371c8bcf6a46857590c107d2"
Unverified Commit 111faf11 authored by Or Ozeri's avatar Or Ozeri Committed by GitHub
Browse files

[Core] Scheduler: Publish connector events after output (#25875)


Signed-off-by: default avatarOr Ozeri <oro@il.ibm.com>
parent 6afc28a9
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import socket
import time import time
import msgspec
import msgspec.msgpack
import pytest import pytest
import zmq
from tqdm import tqdm
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVTransferConfig from vllm.config import KVEventsConfig, KVTransferConfig
from vllm.distributed.kv_events import BlockStored, KVEventBatch
CPU_BLOCK_SIZES = [16, 48] CPU_BLOCK_SIZES = [16, 48]
class MockSubscriber:
"""Helper class to receive and verify published events"""
def __init__(
self,
endpoint: str,
topic: str,
):
self.ctx = zmq.Context.instance()
self.topic_bytes = topic.encode("utf-8")
# Set up subscriber socket
self.sub = self.ctx.socket(zmq.SUB)
self.sub.setsockopt(zmq.SUBSCRIBE, self.topic_bytes)
self.sub.connect(endpoint)
self.decoder = msgspec.msgpack.Decoder(type=KVEventBatch)
def get_new_cpu_stored_events(self) -> list[BlockStored]:
cpu_stored_events: list[BlockStored] = []
poller = zmq.Poller()
poller.register(self.sub, zmq.POLLIN)
timeout = 1000 # 1 second
while True:
events = dict(poller.poll(timeout))
if events.get(self.sub) != zmq.POLLIN:
return cpu_stored_events
topic_bytes, _, payload = self.sub.recv_multipart()
assert topic_bytes == self.topic_bytes
event_batch = self.decoder.decode(payload)
assert isinstance(event_batch, KVEventBatch)
for event in event_batch.events:
if isinstance(event, BlockStored) and event.medium == "CPU":
cpu_stored_events.append(event)
timeout = 100
def close(self):
"""Clean up resources"""
self.sub.close()
@pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES) @pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES)
def test_cpu_offloading(cpu_block_size: int) -> None: def test_cpu_offloading(cpu_block_size: int) -> None:
""" """
...@@ -20,41 +73,80 @@ def test_cpu_offloading(cpu_block_size: int) -> None: ...@@ -20,41 +73,80 @@ def test_cpu_offloading(cpu_block_size: int) -> None:
kv_transfer_config = KVTransferConfig( kv_transfer_config = KVTransferConfig(
kv_connector="OffloadingConnector", kv_connector="OffloadingConnector",
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={"num_cpu_blocks": 100, "block_size": cpu_block_size}, kv_connector_extra_config={
"num_cpu_blocks": 1000,
"block_size": cpu_block_size,
},
)
port: int
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("0.0.0.0", 0))
port = s.getsockname()[1]
events_endpoint = f"tcp://*:{port}"
kv_events_config = KVEventsConfig(
enable_kv_cache_events=True,
publisher="zmq",
endpoint=events_endpoint,
topic="test",
) )
llm = LLM( llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct", model="meta-llama/Llama-3.2-1B-Instruct",
gpu_memory_utilization=0.5, gpu_memory_utilization=0.5,
kv_events_config=kv_events_config,
kv_transfer_config=kv_transfer_config, kv_transfer_config=kv_transfer_config,
disable_hybrid_kv_cache_manager=True,
) )
prompts = ["Hi " * 100] sampling_params = SamplingParams(temperature=0, max_tokens=1)
sampling_params = SamplingParams(temperature=0, max_tokens=20)
events_endpoint = events_endpoint.replace("*", "127.0.0.1")
subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic)
try:
num_times_cpu_better_than_cold = 0
num_tests = 10
total_cold_time = 0.0
total_gpu_hit_time = 0.0
total_cpu_hit_time = 0.0
prompt_token_ids = [0] * 10001
for i in tqdm(range(num_tests), desc="Running tests"):
prompt_token_ids[0] = i
prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]
# run generation - this should trigger saving KV cache # run generation - this should trigger saving KV cache
start_time = time.time() start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False) llm.generate(prompts, sampling_params, use_tqdm=False)
cold_time = time.time() - start_time cold_time = time.time() - start_time
total_cold_time += cold_time
# run generation again - should hit the GPU prefix cache # run generation again - should hit the GPU prefix cache
start_time = time.time() start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False) llm.generate(prompts, sampling_params, use_tqdm=False)
gpu_hit_time = time.time() - start_time gpu_hit_time = time.time() - start_time
total_gpu_hit_time += gpu_hit_time
# reset prefix cache to avoid GPU hit. # reset prefix cache to avoid GPU hit.
llm.reset_prefix_cache() llm.reset_prefix_cache()
# sleep for a sec to make sure CPU finished storing assert subscriber.get_new_cpu_stored_events()
time.sleep(1)
# run generation again - this should trigger loading from CPU # run generation again - this should trigger loading from CPU
start_time = time.time() start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False) llm.generate(prompts, sampling_params, use_tqdm=False)
cpu_hit_time = time.time() - start_time cpu_hit_time = time.time() - start_time
total_cpu_hit_time += cpu_hit_time
if cpu_hit_time < cold_time:
num_times_cpu_better_than_cold += 1
print("Average times:")
print(f" Cold: {total_cold_time * 1000 / num_tests:.2f}ms")
print(f" GPU hit: {total_gpu_hit_time * 1000 / num_tests:.2f}ms")
print(f" CPU hit: {total_cpu_hit_time * 1000 / num_tests:.2f}ms")
print("Generation times:") assert num_times_cpu_better_than_cold >= 0.8 * num_tests
print(f" Cold: {cold_time * 1000:.2f}ms") finally:
print(f" GPU hit: {gpu_hit_time * 1000:.2f}ms") subscriber.close()
print(f" CPU hit: {cpu_hit_time * 1000:.2f}ms") del llm
...@@ -646,23 +646,6 @@ class Scheduler(SchedulerInterface): ...@@ -646,23 +646,6 @@ class Scheduler(SchedulerInterface):
meta = self.connector.build_connector_meta(scheduler_output) meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta scheduler_output.kv_connector_metadata = meta
# collect KV cache events from KV cache manager
events = self.kv_cache_manager.take_events()
# collect KV cache events from connector
if self.connector is not None:
connector_events = self.connector.take_events()
if connector_events:
if events is None:
events = list(connector_events)
else:
events.extend(connector_events)
# publish collected KV cache events
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
self._update_after_schedule(scheduler_output) self._update_after_schedule(scheduler_output)
return scheduler_output return scheduler_output
...@@ -1057,6 +1040,23 @@ class Scheduler(SchedulerInterface): ...@@ -1057,6 +1040,23 @@ class Scheduler(SchedulerInterface):
if kv_connector_output: if kv_connector_output:
self._update_from_kv_xfer_finished(kv_connector_output) self._update_from_kv_xfer_finished(kv_connector_output)
# collect KV cache events from KV cache manager
events = self.kv_cache_manager.take_events()
# collect KV cache events from connector
if self.connector is not None:
connector_events = self.connector.take_events()
if connector_events:
if events is None:
events = list(connector_events)
else:
events.extend(connector_events)
# publish collected KV cache events
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
# Create EngineCoreOutputs for all clients that have requests with # Create EngineCoreOutputs for all clients that have requests with
# outputs in this step. # outputs in this step.
engine_core_outputs = { engine_core_outputs = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment