"pcdet/vscode:/vscode.git/clone" did not exist on "c4033be4f7273c6083e6880ddd51651acb2feead"
Commit 8588e33a authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: Add KV publisher and receiver. Add KV aware routing example.


Signed-off-by: default avatarNeelay Shah <neelays@nvidia.com>
Co-authored-by: default avataraflowers <aflowers@nvidia.com>
Co-authored-by: default avatarRyan McCormick <rmccormick@nvidia.com>
Co-authored-by: default avatarhongkuanz <hongkuanz@nvidia.com>
Co-authored-by: default avatarNeelay Shah <neelays@nvidia.com>
parent d8aada0b
...@@ -20,6 +20,7 @@ from typing import Any, AsyncGenerator, Callable, Type ...@@ -20,6 +20,7 @@ from typing import Any, AsyncGenerator, Callable, Type
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from triton_distributed_rs._core import DistributedRuntime from triton_distributed_rs._core import DistributedRuntime
from triton_distributed_rs._core import KvRouter as KvRouter
def triton_worker(): def triton_worker():
......
from typing import AsyncGenerator, AsyncIterator, Callable # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import AsyncGenerator, AsyncIterator, Callable, List
class JsonLike: class JsonLike:
""" """
...@@ -54,6 +69,12 @@ class Component: ...@@ -54,6 +69,12 @@ class Component:
""" """
... ...
def event_subject(self, name: str) -> str:
"""
Create an event subject
"""
...
class Endpoint: class Endpoint:
""" """
An Endpoint is a single API endpoint An Endpoint is a single API endpoint
...@@ -74,6 +95,12 @@ class Endpoint: ...@@ -74,6 +95,12 @@ class Endpoint:
""" """
... ...
async def lease_id(self) -> int:
"""
Return primary lease id. Currently, cannot set a different lease id.
"""
...
class Client: class Client:
""" """
A client capable of calling served instances of an endpoint A client capable of calling served instances of an endpoint
...@@ -98,3 +125,22 @@ class Client: ...@@ -98,3 +125,22 @@ class Client:
Pick a specific instance of the endpoint Pick a specific instance of the endpoint
""" """
... ...
class KvRouter:
"""
The runtime object for a distributed NOVA applications
"""
...
def __init__(self, drt: DistributedRuntime, component: Component) -> KvRouter:
"""
Create a `KvRouter` object that is associated with the `component`
"""
def schedule(self, token_ids: List[int], lora_id: int) -> str:
"""
Return the worker id that should handle the given token ids,
exception will be raised if there is no worker available.
"""
...
...@@ -22,14 +22,19 @@ use pyo3::{exceptions::PyException, prelude::*}; ...@@ -22,14 +22,19 @@ use pyo3::{exceptions::PyException, prelude::*};
use rs::pipeline::network::Ingress; use rs::pipeline::network::Ingress;
use std::{fmt::Display, sync::Arc}; use std::{fmt::Display, sync::Arc};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing_subscriber::FmtSubscriber;
use triton_distributed::{ use triton_distributed::{
self as rs, self as rs,
pipeline::{EngineStream, ManyOut, SingleIn}, pipeline::{EngineStream, ManyOut, SingleIn},
protocols::annotated::Annotated as RsAnnotated, protocols::annotated::Annotated as RsAnnotated,
traits::DistributedRuntimeProvider,
}; };
use triton_llm::{self as llm_rs};
mod engine; mod engine;
mod llm;
type JsonServerStreamingIngress = type JsonServerStreamingIngress =
Ingress<SingleIn<serde_json::Value>, ManyOut<RsAnnotated<serde_json::Value>>>; Ingress<SingleIn<serde_json::Value>, ManyOut<RsAnnotated<serde_json::Value>>>;
...@@ -43,6 +48,16 @@ const DEFAULT_ANNOTATED_SETTING: Option<bool> = Some(true); ...@@ -43,6 +48,16 @@ const DEFAULT_ANNOTATED_SETTING: Option<bool> = Some(true);
/// import the module. /// import the module.
#[pymodule] #[pymodule]
fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
// Sets up RUST_LOG environment variable for logging through the python-wheel
// Example: RUST_LOG=debug python3 -m ...
let subscriber = FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.finish();
tracing::subscriber::set_global_default(subscriber)
.expect("setting default subscriber failed");
m.add_class::<DistributedRuntime>()?; m.add_class::<DistributedRuntime>()?;
m.add_class::<CancellationToken>()?; m.add_class::<CancellationToken>()?;
m.add_class::<Namespace>()?; m.add_class::<Namespace>()?;
...@@ -50,6 +65,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -50,6 +65,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Endpoint>()?; m.add_class::<Endpoint>()?;
m.add_class::<Client>()?; m.add_class::<Client>()?;
m.add_class::<AsyncResponseStream>()?; m.add_class::<AsyncResponseStream>()?;
m.add_class::<llm::kv::KvRouter>()?;
engine::add_to_module(m)?; engine::add_to_module(m)?;
...@@ -72,31 +88,35 @@ struct DistributedRuntime { ...@@ -72,31 +88,35 @@ struct DistributedRuntime {
event_loop: PyObject, event_loop: PyObject,
} }
#[derive(Clone)]
#[pyclass] #[pyclass]
#[derive(Clone)]
struct CancellationToken { struct CancellationToken {
inner: rs::CancellationToken, inner: rs::CancellationToken,
} }
#[pyclass] #[pyclass]
#[derive(Clone)]
struct Namespace { struct Namespace {
inner: rs::component::Namespace, inner: rs::component::Namespace,
event_loop: PyObject, event_loop: PyObject,
} }
#[pyclass] #[pyclass]
#[derive(Clone)]
struct Component { struct Component {
inner: rs::component::Component, inner: rs::component::Component,
event_loop: PyObject, event_loop: PyObject,
} }
#[pyclass] #[pyclass]
#[derive(Clone)]
struct Endpoint { struct Endpoint {
inner: rs::component::Endpoint, inner: rs::component::Endpoint,
event_loop: PyObject, event_loop: PyObject,
} }
#[pyclass] #[pyclass]
#[derive(Clone)]
struct Client { struct Client {
inner: rs::component::Client<serde_json::Value, serde_json::Value>, inner: rs::component::Client<serde_json::Value, serde_json::Value>,
} }
...@@ -105,18 +125,18 @@ struct Client { ...@@ -105,18 +125,18 @@ struct Client {
impl DistributedRuntime { impl DistributedRuntime {
#[new] #[new]
fn new(event_loop: PyObject) -> PyResult<Self> { fn new(event_loop: PyObject) -> PyResult<Self> {
let rt = rs::Worker::from_settings().map_err(to_pyerr)?; let worker = rs::Worker::from_settings().map_err(to_pyerr)?;
INIT.get_or_try_init(|| { INIT.get_or_try_init(|| {
let primary = rt.tokio_runtime()?; let primary = worker.tokio_runtime()?;
pyo3_async_runtimes::tokio::init_with_runtime(primary) pyo3_async_runtimes::tokio::init_with_runtime(primary)
.map_err(|e| rs::error!("failed to initialize pyo3 static runtime: {:?}", e))?; .map_err(|e| rs::error!("failed to initialize pyo3 static runtime: {:?}", e))?;
rs::OK(()) rs::OK(())
}) })
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
let runtime = rt.runtime().clone(); let runtime = worker.runtime().clone();
let inner = rt let inner = worker
.runtime() .runtime()
.secondary() .secondary()
.block_on(rs::DistributedRuntime::from_settings(runtime)) .block_on(rs::DistributedRuntime::from_settings(runtime))
...@@ -183,6 +203,10 @@ impl Component { ...@@ -183,6 +203,10 @@ impl Component {
Ok(()) Ok(())
}) })
} }
fn event_subject(&self, name: String) -> String {
self.inner.event_subject(name)
}
} }
#[pymethods] #[pymethods]
...@@ -214,6 +238,10 @@ impl Endpoint { ...@@ -214,6 +238,10 @@ impl Endpoint {
Ok(Client { inner: client }) Ok(Client { inner: client })
}) })
} }
fn lease_id(&self) -> i64 {
self.inner.drt().primary_lease().id()
}
} }
#[pymethods] #[pymethods]
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
pub mod kv;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::*;
#[pyclass]
pub(crate) struct KvRouter {
inner: Arc<llm_rs::kv_router::KvRouter>,
}
#[pymethods]
impl KvRouter {
#[new]
fn new(drt: DistributedRuntime, component: Component) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let inner = llm_rs::kv_router::KvRouter::from_runtime(
drt.inner.clone(),
component.inner.clone(),
)
.await
.map_err(to_pyerr)?;
Ok(Self { inner })
})
}
fn schedule<'p>(
&self,
py: Python<'p>,
token_ids: Vec<u32>,
lora_id: u64,
) -> PyResult<Bound<'p, PyAny>> {
let router = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let uuid = router
.schedule(&token_ids, lora_id)
.await
.map_err(to_pyerr)?;
Ok(uuid.to_string())
})
}
}
...@@ -77,7 +77,12 @@ class TritonCoreOperator(Operator): ...@@ -77,7 +77,12 @@ class TritonCoreOperator(Operator):
if repository: if repository:
self._triton_core.register_model_repository(repository) self._triton_core.register_model_repository(repository)
parameter_config = self._parameters.get("config", None) parameter_config = self._parameters.get("config", {})
if "parameters" not in parameter_config:
parameter_config["parameters"] = {}
parameter_config["parameters"]["component_id"] = {
"string_value": f"{self._request_plane.component_id}"
}
model_config = None model_config = None
...@@ -92,17 +97,14 @@ class TritonCoreOperator(Operator): ...@@ -92,17 +97,14 @@ class TritonCoreOperator(Operator):
except Exception: except Exception:
pass pass
if parameter_config and model_config: parameter_config = json_format.Parse(
model_config.MergeFrom( json.dumps(parameter_config), model_config_pb2.ModelConfig()
json_format.Parse( )
json.dumps(parameter_config), model_config_pb2.ModelConfig() if model_config:
) model_config.MergeFrom(parameter_config)
)
model_config = {"config": json_format.MessageToJson(model_config)}
elif parameter_config:
model_config = {"config": parameter_config}
else: else:
model_config = None model_config = parameter_config
model_config = {"config": json_format.MessageToJson(model_config)}
self._triton_core_model = self._triton_core.load(self._name, model_config) self._triton_core_model = self._triton_core.load(self._name, model_config)
@staticmethod @staticmethod
......
...@@ -41,6 +41,56 @@ dependencies = [ ...@@ -41,6 +41,56 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "anstream"
version = "0.6.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b"
dependencies = [
"anstyle",
"anstyle-parse",
"anstyle-query",
"anstyle-wincon",
"colorchoice",
"is_terminal_polyfill",
"utf8parse",
]
[[package]]
name = "anstyle"
version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9"
[[package]]
name = "anstyle-parse"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9"
dependencies = [
"utf8parse",
]
[[package]]
name = "anstyle-query"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "anstyle-wincon"
version = "3.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e"
dependencies = [
"anstyle",
"once_cell",
"windows-sys 0.59.0",
]
[[package]] [[package]]
name = "anyhow" name = "anyhow"
version = "1.0.95" version = "1.0.95"
...@@ -363,6 +413,12 @@ dependencies = [ ...@@ -363,6 +413,12 @@ dependencies = [
"windows-targets", "windows-targets",
] ]
[[package]]
name = "colorchoice"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990"
[[package]] [[package]]
name = "const-oid" name = "const-oid"
version = "0.9.6" version = "0.9.6"
...@@ -701,6 +757,29 @@ dependencies = [ ...@@ -701,6 +757,29 @@ dependencies = [
"syn 2.0.98", "syn 2.0.98",
] ]
[[package]]
name = "env_filter"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0"
dependencies = [
"log",
"regex",
]
[[package]]
name = "env_logger"
version = "0.11.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0"
dependencies = [
"anstream",
"anstyle",
"env_filter",
"humantime",
"log",
]
[[package]] [[package]]
name = "equivalent" name = "equivalent"
version = "1.0.1" version = "1.0.1"
...@@ -870,6 +949,12 @@ version = "0.3.31" ...@@ -870,6 +949,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]] [[package]]
name = "futures-util" name = "futures-util"
version = "0.3.31" version = "0.3.31"
...@@ -927,6 +1012,12 @@ version = "0.31.1" ...@@ -927,6 +1012,12 @@ version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "glob"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.4.7" version = "0.4.7"
...@@ -1272,6 +1363,12 @@ dependencies = [ ...@@ -1272,6 +1363,12 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.13.0" version = "0.13.0"
...@@ -1756,6 +1853,15 @@ dependencies = [ ...@@ -1756,6 +1853,15 @@ dependencies = [
"syn 2.0.98", "syn 2.0.98",
] ]
[[package]]
name = "proc-macro-crate"
version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b"
dependencies = [
"toml_edit",
]
[[package]] [[package]]
name = "proc-macro-error-attr2" name = "proc-macro-error-attr2"
version = "2.0.0" version = "2.0.0"
...@@ -1985,6 +2091,12 @@ version = "0.8.5" ...@@ -1985,6 +2091,12 @@ version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "relative-path"
version = "1.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2"
[[package]] [[package]]
name = "ring" name = "ring"
version = "0.17.8" version = "0.17.8"
...@@ -2000,6 +2112,36 @@ dependencies = [ ...@@ -2000,6 +2112,36 @@ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
[[package]]
name = "rstest"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0a2c585be59b6b5dd66a9d2084aa1d8bd52fbdb806eafdeffb52791147862035"
dependencies = [
"futures",
"futures-timer",
"rstest_macros",
"rustc_version",
]
[[package]]
name = "rstest_macros"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "825ea780781b15345a146be27eaefb05085e337e869bff01b4306a4fd4a9ad5a"
dependencies = [
"cfg-if 1.0.0",
"glob",
"proc-macro-crate",
"proc-macro2",
"quote",
"regex",
"relative-path",
"rustc_version",
"syn 2.0.98",
"unicode-ident",
]
[[package]] [[package]]
name = "rustc-demangle" name = "rustc-demangle"
version = "0.1.24" version = "0.1.24"
...@@ -2794,6 +2936,7 @@ dependencies = [ ...@@ -2794,6 +2936,7 @@ dependencies = [
"derive_builder", "derive_builder",
"educe", "educe",
"either", "either",
"env_logger",
"etcd-client", "etcd-client",
"figment", "figment",
"futures", "futures",
...@@ -2807,6 +2950,7 @@ dependencies = [ ...@@ -2807,6 +2950,7 @@ dependencies = [
"prometheus", "prometheus",
"rand", "rand",
"regex", "regex",
"rstest",
"serde", "serde",
"serde_json", "serde_json",
"socket2", "socket2",
...@@ -2888,6 +3032,12 @@ version = "1.0.4" ...@@ -2888,6 +3032,12 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "utf8parse"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]] [[package]]
name = "uuid" name = "uuid"
version = "1.12.1" version = "1.12.1"
......
...@@ -74,3 +74,5 @@ rand = { version = "0.8"} ...@@ -74,3 +74,5 @@ rand = { version = "0.8"}
[dev-dependencies] [dev-dependencies]
assert_matches = "1.5.0" assert_matches = "1.5.0"
env_logger = "0.11"
rstest = "0.23.0"
...@@ -125,10 +125,22 @@ impl Component { ...@@ -125,10 +125,22 @@ impl Component {
format!("{}/components/{}", self.namespace, self.name) format!("{}/components/{}", self.namespace, self.name)
} }
pub fn drt(&self) -> &DistributedRuntime {
&self.drt
}
fn slug(&self) -> Slug { fn slug(&self) -> Slug {
Slug::from_string(self.etcd_path()) Slug::from_string(self.etcd_path())
} }
pub fn service_name(&self) -> String {
self.slug().to_string()
}
pub fn event_subject(&self, name: impl AsRef<str>) -> String {
format!("{}.events.{}", self.slug(), name.as_ref())
}
pub fn endpoint(&self, endpoint: impl Into<String>) -> Endpoint { pub fn endpoint(&self, endpoint: impl Into<String>) -> Endpoint {
Endpoint { Endpoint {
component: self.clone(), component: self.clone(),
...@@ -189,6 +201,10 @@ impl Endpoint { ...@@ -189,6 +201,10 @@ impl Endpoint {
&self.name &self.name
} }
pub fn component(&self) -> &Component {
&self.component
}
pub fn etcd_path(&self) -> String { pub fn etcd_path(&self) -> String {
format!("{}/{}", self.component.etcd_path(), self.name) format!("{}/{}", self.component.etcd_path(), self.name)
} }
......
...@@ -24,6 +24,7 @@ pub struct EndpointConfig { ...@@ -24,6 +24,7 @@ pub struct EndpointConfig {
#[builder(private)] #[builder(private)]
endpoint: Endpoint, endpoint: Endpoint,
// todo: move lease to component/service
/// Lease /// Lease
#[educe(Debug(ignore))] #[educe(Debug(ignore))]
#[builder(default)] #[builder(default)]
......
...@@ -93,3 +93,24 @@ impl ServiceConfigBuilder { ...@@ -93,3 +93,24 @@ impl ServiceConfigBuilder {
Self::default().component(component) Self::default().component(component)
} }
} }
// // Wrap the optional user callback method in a closure that appends the lease_id to the response
// fn wrap_callback(
// callback: Option<Box<dyn FnMut(String, Stats) -> Value + Send + Sync>>,
// lease_id: i64,
// ) -> Box<dyn FnMut(String, Stats) -> Value + Send + Sync> {
// let callback = Arc::new(Mutex::new(callback)); // Wrap in Arc<Mutex> for shared access
// Box::new(move |subject: String, stats: Stats| -> Value {
// let mut callback_lock = callback.lock().unwrap();
// if let Some(cb) = callback_lock.as_mut() {
// let mut result = cb(subject, stats); // Call the user-defined callback
// if let Some(obj) = result.as_object_mut() {
// obj.insert("lease_id".to_string(), json!(lease_id)); // Append lease_id
// }
// result
// } else {
// json!({ "error": "callback not set", "lease_id": lease_id }) // Default response
// }
// })
// }
...@@ -107,7 +107,7 @@ impl DistributedRuntime { ...@@ -107,7 +107,7 @@ impl DistributedRuntime {
ServiceClient::new(self.nats_client.clone()) ServiceClient::new(self.nats_client.clone())
} }
pub(crate) async fn tcp_server(&self) -> Result<Arc<tcp::server::TcpStreamServer>> { pub async fn tcp_server(&self) -> Result<Arc<tcp::server::TcpStreamServer>> {
Ok(self Ok(self
.tcp_server .tcp_server
.get_or_try_init(async move { .get_or_try_init(async move {
......
...@@ -34,8 +34,7 @@ pub struct ServiceClient { ...@@ -34,8 +34,7 @@ pub struct ServiceClient {
} }
impl ServiceClient { impl ServiceClient {
#[allow(dead_code)] pub fn new(nats_client: nats::Client) -> Self {
pub(crate) fn new(nats_client: nats::Client) -> Self {
ServiceClient { nats_client } ServiceClient { nats_client }
} }
} }
...@@ -85,9 +84,13 @@ impl ServiceClient { ...@@ -85,9 +84,13 @@ impl ServiceClient {
Ok(response) Ok(response)
} }
pub async fn collect_services(&self, service_name: &str) -> Result<ServiceSet> { pub async fn collect_services(
&self,
service_name: &str,
duration: Duration,
) -> Result<ServiceSet> {
let mut sub = self.nats_client.service_subscriber(service_name).await?; let mut sub = self.nats_client.service_subscriber(service_name).await?;
let deadline = tokio::time::Instant::now() + Duration::from_secs(1); let deadline = tokio::time::Instant::now() + duration;
let services: Vec<Result<ServiceInfo>> = try_stream! { let services: Vec<Result<ServiceInfo>> = try_stream! {
while let Ok(Some(message)) = tokio::time::timeout_at(deadline, sub.next()).await { while let Ok(Some(message)) = tokio::time::timeout_at(deadline, sub.next()).await {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment