Commit dcecc47d authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

test: add tests for kv bindings (#35)

parent 6705d483
......@@ -106,7 +106,7 @@ ENV VLLM_GENERATE_WORKERS=${VLLM_FRAMEWORK:+1}
ENV VLLM_BASELINE_TP_SIZE=${VLLM_FRAMEWORK:+1}
ENV VLLM_CONTEXT_TP_SIZE=${VLLM_FRAMEWORK:+1}
ENV VLLM_GENERATE_TP_SIZE=${VLLM_FRAMEWORK:+1}
ENV VLLM_KV_CAPI_PATH="/opt/dynemo/llm_binding/lib/libdynemo_llm_capi.so"
ENV VLLM_KV_CAPI_PATH="/opt/dynemo/bindings/lib/libdynemo_llm_capi.so"
ENV PYTHONUNBUFFERED=1
# Install NATS - pointing toward NATS github instead of binaries.nats.dev due to server instability
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uvloop
from common.protocol import Request, Response
from vllm.logger import logger as vllm_logger
from dynemo.llm import KvRouter
from dynemo.runtime import DistributedRuntime, dynemo_endpoint, dynemo_worker
class Router:
"""
Request handler for the generate endpoint
"""
def __init__(
self,
router,
workers_client,
):
self.router = router
self.workers_client = workers_client
@dynemo_endpoint(Request, Response)
async def generate(self, request):
lora_id = 0
worker_id = None
tokens = [3] * 64
try:
worker_id = await self.router.schedule(tokens, lora_id)
# [NOTE][TODO] Now that the scheduler may return more error messages,
# now we are catching all exceptions and logging them. Should have
# catch specific router exceptions once we have dedicated types.
except Exception as e:
vllm_logger.info(f"got exception of type {type(e)}: {e}")
worker_id = None
vllm_logger.exception(f"Error during worker selection: {e}")
vllm_logger.info(f"Scheduling to worker_id: {worker_id}")
if worker_id is None:
vllm_logger.info("randomly select worker")
engine_generator = await self.workers_client.random(
request.model_dump_json()
)
else:
vllm_logger.info(f"directly select worker: {worker_id}")
engine_generator = await self.workers_client.direct(
request.model_dump_json(), worker_id
)
async for resp in engine_generator:
resp = resp.data() if hasattr(resp, "data") else resp
yield resp
@dynemo_endpoint(Request, Response)
async def mock_generate(self, request):
print(f"Received request: {request}")
yield "Hello, World!"
ROUTE_SELF = True
@dynemo_worker()
async def worker(runtime: DistributedRuntime):
workers_client = (
await runtime.namespace("dynemo")
.component("vllm")
.endpoint("generate")
.client()
)
vllm_logger.info(
f"Have number of workers ({len(workers_client.endpoint_ids())}) are ready:\n"
+ "\n".join(f"id: {id}" for id in workers_client.endpoint_ids())
)
# [TODO] Collect endpoint implementation expects services to provide
# ForwardPassMetrics as part of stats handling and it will panic if
# otherwise. This needs to be fixed so that non-providing endpoints will
# simply be ignored, but before that, we will make sure that the services
# of the same namespace::component are created via KvMetricsPublisher,
# if it is also used to create endpoints.
kv_listener = runtime.namespace("dynemo").component("vllm")
await kv_listener.create_service()
router = KvRouter(runtime, kv_listener)
# i.e. below will cause panic
# endpoint = kv_listener.endpoint("generate")
# await endpoint.serve_endpoint(
# Router(router, workers_client).mock_generate
# )
router_component = runtime.namespace("dynemo").component("frontend")
await router_component.create_service()
endpoint = router_component.endpoint("generate")
await endpoint.serve_endpoint(Router(router, workers_client).generate)
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
......@@ -16,39 +16,97 @@
import asyncio
import ctypes
import os
import subprocess
from ctypes import c_char_p, c_int64, c_uint32
from time import sleep
from typing import List
import uvloop
from common.protocol import Request, Response
from vllm.logger import logger as vllm_logger
import pytest
from dynemo.llm import KvMetricsPublisher
from dynemo.runtime import DistributedRuntime, dynemo_endpoint, dynemo_worker
from dynemo.llm import KvIndexer, KvMetricsAggregator, KvMetricsPublisher
from dynemo.runtime import DistributedRuntime
pytestmark = pytest.mark.pre_merge
runtime = None
@pytest.fixture(scope="module", autouse=True)
def setup_and_teardown():
# Setup code
nats_server = subprocess.Popen(["nats-server", "-js"])
etcd = subprocess.Popen(["etcd"])
print("Setting up resources")
sleep(5) # wait for nats-server and etcd to start
yield
# Teardown code
print("Tearing down resources")
nats_server.terminate()
nats_server.wait()
etcd.terminate()
etcd.wait()
async def test_event_handler():
global runtime
if runtime is None:
loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop)
namespace = "kv_test"
component = "event"
# publisher
worker_id = 233
event_publisher = EventPublisher(namespace, component, worker_id)
# indexer
kv_listener = runtime.namespace(namespace).component(component)
await kv_listener.create_service()
indexer = KvIndexer(kv_listener)
test_token = [3] * 64
lora_id = 0 # lora_id is not used in the indexer
scores = await indexer.find_matches_for_request(test_token, lora_id)
assert not scores.scores
event_publisher.store_event(test_token, lora_id)
# wait for the event to be processed as it is sent asynchronously
await asyncio.sleep(1)
scores = await indexer.find_matches_for_request(test_token, lora_id)
assert scores.scores
assert worker_id in scores.scores
assert scores.scores[worker_id] == 1
# remove event
event_publisher.remove_event()
await asyncio.sleep(1)
scores = await indexer.find_matches_for_request(test_token, lora_id)
assert not scores.scores
# KV events
class DynemoResult:
OK = 0
ERR = 1
class MockEngine:
"""
Request handler for the generate endpoint
"""
class EventPublisher:
def __init__(self, namespace: str, component: str, worker_id: int):
self.event_id_counter = 0
self.block_ids: List[int] = []
def __init__(self, metrics_publisher, worker_id):
self.worker_id = worker_id
# KV events
self.lib = ctypes.CDLL("/opt/dynemo/llm_binding/lib/libdynemo_llm_capi.so")
# load event publisher library
self.lib = ctypes.CDLL(os.environ["VLLM_KV_CAPI_PATH"])
self.lib.dynemo_llm_init.argtypes = [c_char_p, c_char_p, c_int64]
self.lib.dynemo_llm_init.restype = c_uint32
result = self.lib.dynemo_llm_init("dynemo".encode(), "vllm".encode(), worker_id)
if result == DynemoResult.OK:
vllm_logger.info(
"KVCacheEventManager initialized successfully. Ready to publish KV Cache Events"
)
else:
vllm_logger.info("KVCacheEventManager initialization failed!")
result = self.lib.dynemo_llm_init(
namespace.encode(), component.encode(), worker_id
)
assert result == DynemoResult.OK
self.lib.dynemo_kv_event_publish_stored.argtypes = [
ctypes.c_uint64, # event_id
ctypes.POINTER(ctypes.c_uint32), # token_ids
......@@ -71,41 +129,7 @@ class MockEngine:
ctypes.c_uint32
) # dynemo_llm_result_t
# KV metrics
self.metrics_publisher = metrics_publisher
self.request_active_slots = 0
self.request_total_slots = 4
self.kv_active_block = 0
self.kv_total_blocks = 4
# [NOTE] Now that the component must has proper metrics reported
# to be properly selected by the router
self.metrics_publisher.publish(
self.request_active_slots,
self.request_total_slots,
self.kv_active_block,
self.kv_total_blocks,
)
self.event_id_counter = 0
self.tokens = [3] * 64
@dynemo_endpoint(Request, Response)
async def generate(self, request):
print(f"Received request: {request}")
self.request_active_slots = min(
self.request_active_slots + 1, self.request_total_slots
)
self.kv_active_block = min(self.kv_active_block + 1, self.kv_total_blocks)
self.metrics_publisher.publish(
self.request_active_slots,
self.request_total_slots,
self.kv_active_block,
self.kv_total_blocks,
)
self.store_event()
yield "Hello, World!"
def store_event(self):
def store_event(self, tokens, lora_id):
parent_hash = (
(ctypes.c_uint64 * 1)(self.event_id_counter)
if self.event_id_counter > 0
......@@ -113,57 +137,80 @@ class MockEngine:
)
result = self.lib.dynemo_kv_event_publish_stored(
self.event_id_counter, # uint64_t event_id
(ctypes.c_uint32 * len(self.tokens))(
*self.tokens
), # const uint32_t *token_ids
(ctypes.c_size_t * 1)(
len(self.tokens)
), # const uintptr_t *num_block_tokens
(ctypes.c_uint32 * len(tokens))(*tokens), # const uint32_t *token_ids
(ctypes.c_size_t * 1)(len(tokens)), # const uintptr_t *num_block_tokens
(ctypes.c_uint64 * 1)(self.event_id_counter), # const uint64_t *block_ids
1, # uintptr_t num_blocks
parent_hash, # const uint64_t *parent_hash
0, # uint64_t lora_id
lora_id, # uint64_t lora_id
)
self.block_ids.append(self.event_id_counter)
self.event_id_counter += 1
if result == DynemoResult.OK:
vllm_logger.debug(f"Store - Published KV Event: {self.event_id_counter}")
else:
vllm_logger.debug(
f"Store - Failed to Publish KV Event: {self.event_id_counter}"
)
async def cooldown(self):
while True:
await asyncio.sleep(5)
self.request_active_slots = max(0, self.request_active_slots - 1)
self.kv_active_block = max(0, self.kv_active_block - 1)
self.metrics_publisher.publish(
self.request_active_slots,
self.request_total_slots,
self.kv_active_block,
self.kv_total_blocks,
)
@dynemo_worker()
async def worker(runtime: DistributedRuntime):
"""
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
"""
component = runtime.namespace("dynemo").component("vllm")
metrics_publisher = KvMetricsPublisher()
await metrics_publisher.create_service(component)
endpoint = component.endpoint("generate")
engine = MockEngine(metrics_publisher, endpoint.lease_id())
await asyncio.gather(
engine.cooldown(),
endpoint.serve_endpoint(engine.generate),
)
assert result == DynemoResult.OK
def remove_event(self):
result = self.lib.dynemo_kv_event_publish_removed(
self.event_id_counter, # uint64_t event_id
(ctypes.c_uint64 * 1)(self.block_ids[-1]), # const uint64_t *block_ids
1, # uintptr_t num_blocks
)
self.event_id_counter += 1
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
assert result == DynemoResult.OK
async def test_metrics_aggregator():
global runtime
if runtime is None:
loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop)
namespace = "kv_test"
component = "metrics"
kv_listener = runtime.namespace(namespace).component(component)
await kv_listener.create_service()
# aggregator
metrics_aggregator = KvMetricsAggregator(kv_listener)
# has nothing to aggregate as worker has not started
metrics = await metrics_aggregator.get_metrics()
assert not metrics.endpoints
expected_metrics = {
"request_active_slots": 0,
"request_total_slots": 1024,
"kv_active_blocks": 523,
"kv_total_blocks": 777,
}
# need 'create_taskk' to put publisher task in the background
asyncio.create_task(metrics_publisher(kv_listener, expected_metrics))
# needs time for publisher to spawn up
for i in range(10):
await asyncio.sleep(1)
metrics = await metrics_aggregator.get_metrics()
if metrics.endpoints:
break
assert metrics.endpoints
for endpoint in metrics.endpoints:
# [TODO] not really checking id for now, can't get it as create_endpoint()
# create and serve the endpoint internally
assert endpoint.worker_id != 0
assert endpoint.request_active_slots == expected_metrics["request_active_slots"]
assert endpoint.request_total_slots == expected_metrics["request_total_slots"]
assert endpoint.kv_active_blocks == expected_metrics["kv_active_blocks"]
assert endpoint.kv_total_blocks == expected_metrics["kv_total_blocks"]
async def metrics_publisher(kv_listener, expected_metrics):
metrics_publisher = KvMetricsPublisher()
metrics_publisher.publish(
expected_metrics["request_active_slots"],
expected_metrics["request_total_slots"],
expected_metrics["kv_active_blocks"],
expected_metrics["kv_total_blocks"],
)
await metrics_publisher.create_endpoint(kv_listener)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment