Commit 3a5fe17d authored by Neelay Shah's avatar Neelay Shah Committed by GitHub
Browse files
parent 8fb1421e
...@@ -22,6 +22,8 @@ from vllm.distributed.device_communicators.nixl import NixlMetadata ...@@ -22,6 +22,8 @@ from vllm.distributed.device_communicators.nixl import NixlMetadata
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
from triton_distributed.runtime import DistributedRuntime
METADATA_DIR = "/tmp/nixl" METADATA_DIR = "/tmp/nixl"
...@@ -63,3 +65,55 @@ def find_remote_metadata(engine_id): ...@@ -63,3 +65,55 @@ def find_remote_metadata(engine_id):
msgspec.msgpack.decode(f.read(), type=NixlMetadata) msgspec.msgpack.decode(f.read(), type=NixlMetadata)
) )
return remote_metadata return remote_metadata
class NixlMetadataStore:
    """Store and retrieve per-engine NIXL metadata via etcd.

    Keys are namespaced as ``{namespace}/nixl_metadata/{engine_id}`` and
    values are msgpack-encoded ``NixlMetadata`` payloads.  Successfully
    retrieved metadata is cached in-process; the cache is currently never
    invalidated (see TODO in ``get``).
    """

    NIXL_METADATA_KEY = "nixl_metadata"

    def __init__(self, namespace: str, runtime: DistributedRuntime) -> None:
        self._namespace = namespace
        # TODO Remove metadata from etcd on delete
        self._stored: set[str] = set()
        self._cached: dict[str, NixlMetadata] = {}
        self._client = runtime.etcd_client()
        self._key_prefix = f"{self._namespace}/{NixlMetadataStore.NIXL_METADATA_KEY}"

    def _key_for(self, engine_id: str) -> str:
        """Build the etcd key under which one engine's metadata lives."""
        return "/".join([self._key_prefix, engine_id])

    async def put(self, engine_id, metadata: NixlMetadata):
        """Serialize *metadata* with msgpack and write it to etcd for *engine_id*."""
        serialized_metadata = msgspec.msgpack.encode(metadata)
        # No lease (None): the entry persists until explicitly deleted.
        await self._client.kv_put(self._key_for(engine_id), serialized_metadata, None)
        self._stored.add(engine_id)

    async def get(self, engine_id) -> NixlMetadata:
        """Return metadata for *engine_id*, preferring the in-process cache.

        Raises:
            Exception: if the metadata is absent from etcd or cannot be
                decoded; the underlying error is chained as ``__cause__``.
        """
        # Cache lookup cannot fail, so keep it outside the try block.
        if engine_id in self._cached:
            return self._cached[engine_id]

        try:
            key_values = await self._client.kv_get_prefix(self._key_for(engine_id))

            # Only the first match is used; the key is expected to be
            # unique per engine.
            deserialized_metadata = None
            for item in key_values:
                deserialized_metadata = msgspec.msgpack.decode(
                    item["value"], type=NixlMetadata
                )
                break

            if deserialized_metadata is None:
                raise Exception("metadata not found in etcd")

            self._cached[engine_id] = deserialized_metadata

            # TODO watch for changes and update cache
            # self._client.add_watch_callback(
            #     key,
            #     self._watch_callback,
            # )
        except Exception as e:
            # BUG FIX: the original message was not an f-string, so callers
            # saw the literal text "{engine_id}" instead of the actual id.
            raise Exception(f"Error retrieving metadata for engine {engine_id}") from e

        return deserialized_metadata
...@@ -18,8 +18,7 @@ import asyncio ...@@ -18,8 +18,7 @@ import asyncio
import msgspec import msgspec
import uvloop import uvloop
from common import find_remote_metadata, parse_vllm_args from common import NixlMetadataStore, parse_vllm_args
from vllm.distributed.device_communicators.nixl import NixlMetadata
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import ( from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args, build_async_engine_client_from_engine_args,
...@@ -31,8 +30,10 @@ from triton_distributed.runtime import DistributedRuntime, triton_worker ...@@ -31,8 +30,10 @@ from triton_distributed.runtime import DistributedRuntime, triton_worker
class RequestHandler: class RequestHandler:
def __init__(self, engine_client): def __init__(self, engine_client, metadata_store):
self.engine_client = engine_client self.engine_client = engine_client
self._metadata_store = metadata_store
self._loaded_metadata = set()
print("RequestHandler initialized") print("RequestHandler initialized")
async def generate(self, raw_request: str): async def generate(self, raw_request: str):
...@@ -50,6 +51,17 @@ class RequestHandler: ...@@ -50,6 +51,17 @@ class RequestHandler:
decode_engine_id=request.engine_id, decode_engine_id=request.engine_id,
) )
# TODO check if metadata has changed
# and reload - currently only loading once
if request.engine_id not in self._loaded_metadata:
remote_metadata = await self._metadata_store.get(request.engine_id)
await self.engine_client.add_remote_nixl_metadata(remote_metadata)
print(
f"Loaded nixl metadata from engine {request.engine_id} into engine {self.engine_client.nixl_metadata.engine_id}"
)
self._loaded_metadata.add(request.engine_id)
async for _ in self.engine_client.generate( async for _ in self.engine_client.generate(
request_id=request.request_id, request_id=request.request_id,
prompt=TokensPrompt(prompt_token_ids=request.prompt_token_ids), prompt=TokensPrompt(prompt_token_ids=request.prompt_token_ids),
...@@ -67,20 +79,13 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs): ...@@ -67,20 +79,13 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
endpoint = component.endpoint("generate") endpoint = component.endpoint("generate")
async with build_async_engine_client_from_engine_args(engine_args) as engine_client: async with build_async_engine_client_from_engine_args(engine_args) as engine_client:
# This should be replaced with etcd
metadata = engine_client.nixl_metadata metadata = engine_client.nixl_metadata
print(f"Waiting for remote metadata for engine {metadata.engine_id}") metadata_store = NixlMetadataStore("test-nixl", runtime)
remote_metadata: list[NixlMetadata] = [] await metadata_store.put(metadata.engine_id, metadata)
while not remote_metadata:
await asyncio.sleep(1) await endpoint.serve_endpoint(
remote_metadata = find_remote_metadata(metadata.engine_id) RequestHandler(engine_client, metadata_store).generate
print(
f"Found {len(remote_metadata)} remote metadata for engine {metadata.engine_id}"
) )
for remote_metadata in remote_metadata:
await engine_client.add_remote_nixl_metadata(remote_metadata)
await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -19,7 +19,7 @@ import json ...@@ -19,7 +19,7 @@ import json
import msgspec import msgspec
import uvloop import uvloop
from common import parse_vllm_args, temp_metadata_file from common import NixlMetadataStore, parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import EngineClient from vllm.engine.multiprocessing.client import EngineClient
from vllm.entrypoints.openai.api_server import ( from vllm.entrypoints.openai.api_server import (
...@@ -132,15 +132,17 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs): ...@@ -132,15 +132,17 @@ async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
if engine_args.remote_prefill: if engine_args.remote_prefill:
metadata = engine_client.nixl_metadata metadata = engine_client.nixl_metadata
with temp_metadata_file(metadata.engine_id, metadata): metadata_store = NixlMetadataStore("test-nixl", runtime)
await endpoint.serve_endpoint( await metadata_store.put(metadata.engine_id, metadata)
RequestHandler(
model_name="vllm", await endpoint.serve_endpoint(
engine_client=engine_client, RequestHandler(
prefill_client=prefill_client, model_name="vllm",
do_remote_prefill=True, engine_client=engine_client,
).generate prefill_client=prefill_client,
) do_remote_prefill=True,
).generate
)
else: else:
await endpoint.serve_endpoint( await endpoint.serve_endpoint(
RequestHandler( RequestHandler(
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
use futures::StreamExt; use futures::StreamExt;
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use pyo3::exceptions::PyStopAsyncIteration; use pyo3::exceptions::PyStopAsyncIteration;
use pyo3::types::PyString; use pyo3::types::PyBytes;
use pyo3::types::{PyDict, PyList, PyString};
use pyo3::IntoPyObjectExt; use pyo3::IntoPyObjectExt;
use pyo3::{exceptions::PyException, prelude::*}; use pyo3::{exceptions::PyException, prelude::*};
use rs::pipeline::network::Ingress; use rs::pipeline::network::Ingress;
...@@ -62,6 +63,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -62,6 +63,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Component>()?; m.add_class::<Component>()?;
m.add_class::<Endpoint>()?; m.add_class::<Endpoint>()?;
m.add_class::<Client>()?; m.add_class::<Client>()?;
m.add_class::<EtcdClient>()?;
m.add_class::<AsyncResponseStream>()?; m.add_class::<AsyncResponseStream>()?;
m.add_class::<llm::kv::KvRouter>()?; m.add_class::<llm::kv::KvRouter>()?;
m.add_class::<llm::kv::KvMetricsPublisher>()?; m.add_class::<llm::kv::KvMetricsPublisher>()?;
...@@ -85,6 +87,12 @@ struct DistributedRuntime { ...@@ -85,6 +87,12 @@ struct DistributedRuntime {
event_loop: PyObject, event_loop: PyObject,
} }
/// Python-exposed handle to the runtime's etcd client.
#[pyclass]
#[derive(Clone)]
struct EtcdClient {
    /// Underlying Rust etcd client; cloning shares the same connection.
    inner: rs::transports::etcd::Client,
}
#[pyclass] #[pyclass]
#[derive(Clone)] #[derive(Clone)]
struct CancellationToken { struct CancellationToken {
...@@ -149,6 +157,12 @@ impl DistributedRuntime { ...@@ -149,6 +157,12 @@ impl DistributedRuntime {
}) })
} }
/// Return a Python-facing `EtcdClient` backed by this runtime's etcd connection.
fn etcd_client(&self) -> PyResult<EtcdClient> {
    let inner = self.inner.etcd_client().clone();
    Ok(EtcdClient { inner })
}
fn primary_token(&self) -> CancellationToken { fn primary_token(&self) -> CancellationToken {
let inner = self.inner.runtime().primary_token(); let inner = self.inner.runtime().primary_token();
CancellationToken { inner } CancellationToken { inner }
...@@ -252,6 +266,73 @@ impl Namespace { ...@@ -252,6 +266,73 @@ impl Namespace {
} }
} }
#[pymethods]
impl EtcdClient {
    /// Atomically create `key` with `value`, or succeed if the key already
    /// exists with an identical value; fail otherwise.  Awaitable from Python.
    #[pyo3(signature = (key, value, lease_id=None))]
    fn kv_create_or_validate<'p>(
        &self,
        py: Python<'p>,
        key: String,
        value: Vec<u8>,
        lease_id: Option<i64>,
    ) -> PyResult<Bound<'p, PyAny>> {
        let client = self.inner.clone();
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            client
                .kv_create_or_validate(key, value, lease_id)
                .await
                .map_err(to_pyerr)?;
            Ok(())
        })
    }

    /// Unconditionally write `key` = `value`, optionally attached to a lease.
    /// Awaitable from Python.
    #[pyo3(signature = (key, value, lease_id=None))]
    fn kv_put<'p>(
        &self,
        py: Python<'p>,
        key: String,
        value: Vec<u8>,
        lease_id: Option<i64>,
    ) -> PyResult<Bound<'p, PyAny>> {
        let client = self.inner.clone();
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            client
                .kv_put(key, value, lease_id)
                .await
                .map_err(to_pyerr)?;
            Ok(())
        })
    }

    /// Fetch every key-value pair whose key starts with `prefix`.
    ///
    /// Resolves to a Python list of dicts with `key` (str), `value` (bytes),
    /// and the etcd revision fields (`create_revision`, `mod_revision`,
    /// `version`, `lease`).
    fn kv_get_prefix<'p>(&self, py: Python<'p>, prefix: String) -> PyResult<Bound<'p, PyAny>> {
        let client = self.inner.clone();
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            // CONSISTENCY FIX: use the shared `to_pyerr` helper like the
            // other methods instead of hand-rolling a PyRuntimeError.
            let result = client.kv_get_prefix(prefix).await.map_err(to_pyerr)?;

            // Convert Vec<KeyValue> into a Python list of dictionaries.
            let py_list = Python::with_gil(|py| {
                let list = PyList::empty(py);
                for kv in result {
                    let dict = PyDict::new(py);
                    // Keys are decoded lossily as UTF-8; values stay raw bytes.
                    dict.set_item("key", String::from_utf8_lossy(kv.key()).to_string())?;
                    dict.set_item("value", PyBytes::new(py, kv.value()))?;
                    dict.set_item("create_revision", kv.create_revision())?;
                    dict.set_item("mod_revision", kv.mod_revision())?;
                    dict.set_item("version", kv.version())?;
                    dict.set_item("lease", kv.lease())?;
                    list.append(dict)?;
                }
                Ok::<Py<PyList>, PyErr>(list.into())
            })?;
            Ok(py_list)
        })
    }
}
#[pymethods] #[pymethods]
impl Client { impl Client {
/// Get list of current endpoints /// Get list of current endpoints
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import AsyncGenerator, AsyncIterator, Callable, List from typing import AsyncGenerator, AsyncIterator, Callable, Dict, List, Optional
class JsonLike: class JsonLike:
""" """
...@@ -37,6 +37,34 @@ class DistributedRuntime: ...@@ -37,6 +37,34 @@ class DistributedRuntime:
""" """
... ...
def etcd_client(self) -> EtcdClient:
    """
    Get the `EtcdClient` object backed by this runtime's etcd connection.
    """
    ...
class EtcdClient:
    """
    Client for the etcd cluster backing the DistributedRuntime.

    Etcd is used for discovery in the DistributedRuntime; this client also
    exposes raw key-value operations on that cluster.
    """

    async def kv_create_or_validate(self, key: str, value: bytes, lease_id: Optional[int] = None) -> None:
        """
        Atomically create a key if it does not exist, or validate the values are identical if the key exists.
        """
        ...

    async def kv_put(self, key: str, value: bytes, lease_id: Optional[int] = None) -> None:
        """
        Put a key-value pair into etcd, optionally attached to a lease.
        """
        ...

    async def kv_get_prefix(self, prefix: str) -> List[Dict[str, JsonLike]]:
        """
        Get all keys with a given prefix.

        Each returned dict contains ``key``, ``value`` (bytes), and etcd
        revision metadata.
        """
        ...
class Namespace: class Namespace:
""" """
A namespace is a collection of components A namespace is a collection of components
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from triton_distributed._core import DistributedRuntime
async def test_simple_put_get():
    """Smoke-test EtcdClient: write keys, read them back by prefix, verify."""
    # Initialize runtime on the current event loop
    loop = asyncio.get_running_loop()
    runtime = DistributedRuntime(loop)

    try:
        # Get etcd client
        etcd = runtime.etcd_client()

        # Key-value pairs to write and later read back
        test_keys = {
            "test/key1": b"value1",
            "test/key2": b"value2",
            "test/nested/key3": b"value3",
        }

        # Write each key-value pair
        for key, value in test_keys.items():
            print(f"Writing {key} = {value!r}")
            await etcd.kv_create_or_validate(key, value, None)
        print("Successfully wrote all keys to etcd")

        # Test kv_put
        put_key = "test/put_key"
        put_value = b"put_value"
        test_keys[put_key] = put_value
        print(f"Using kv_put to write {put_key} = {put_value!r}")
        await etcd.kv_put(put_key, put_value, None)

        # Test kv_get_prefix to read all keys
        print("\nReading all keys with prefix 'test/':")
        keys_values = await etcd.kv_get_prefix("test/")
        for item in keys_values:
            print(f"Retrieved {item['key']} = {item['value']!r}")
            assert test_keys[item["key"]] == item["value"]

        # Verify prefix filtering works
        print("\nReading keys with prefix 'test/nested/':")
        nested_keys_values = await etcd.kv_get_prefix("test/nested/")
        for item in nested_keys_values:
            print(f"Retrieved {item['key']} = {item['value']!r}")
            assert test_keys[item["key"]] == item["value"]
    finally:
        # ROBUSTNESS FIX: shut the runtime down even when an assertion or
        # etcd call above fails, so background resources are not leaked.
        runtime.shutdown()
if __name__ == "__main__":
    # Run the smoke test as a standalone script.
    asyncio.run(test_simple_put_get())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment