Commit 88ad3425, authored by Graham King, committed by GitHub
Browse files

feat: Python decorator dynamo_worker takes optional `static` parameter to run without etcd (#494)

Adds `@dynamo_worker(static=True)` to create a static worker, which has a predictable name and hence does not require discovery or `etcd` to be running. Only one static worker can exist per namespace / component / endpoint trio.

This contrasts with the default dynamic `dynamo_worker` endpoints we have now, which get a unique random name (based on namespace/component/endpoint), and are discovered by ingress components using etcd.

Also change the hello_world example to use `dynamo_worker(static=True)` so that it is exercised and demonstrated somewhere.

For NIM.
parent bd8f0804
...@@ -89,12 +89,14 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -89,12 +89,14 @@ async fn app(runtime: Runtime) -> Result<()> {
drt: distributed.clone(), drt: distributed.clone(),
}); });
let etcd_client = distributed.etcd_client(); if let Some(etcd_client) = distributed.etcd_client() {
let models_watcher: PrefixWatcher = etcd_client.kv_get_and_watch_prefix(etcd_path).await?; let models_watcher: PrefixWatcher =
etcd_client.kv_get_and_watch_prefix(etcd_path).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let watcher_task = tokio::spawn(model_watcher(state, receiver)); let (_prefix, _watcher, receiver) = models_watcher.dissolve();
watcher_tasks.push(watcher_task); let watcher_task = tokio::spawn(model_watcher(state, receiver));
watcher_tasks.push(watcher_task);
}
} }
// Run the service // Run the service
......
...@@ -126,10 +126,11 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -126,10 +126,11 @@ async fn app(runtime: Runtime) -> Result<()> {
let key = format!("{}/instance", component.etcd_path()); let key = format!("{}/instance", component.etcd_path());
tracing::debug!("Creating unique instance of Count at {key}"); tracing::debug!("Creating unique instance of Count at {key}");
drt.etcd_client() drt.etcd_client()
.expect("Unreachable because of DistributedRuntime::from_settings above")
.kv_create( .kv_create(
key, key,
serde_json::to_vec_pretty(&config)?, serde_json::to_vec_pretty(&config)?,
Some(drt.primary_lease().id()), Some(drt.primary_lease().unwrap().id()),
) )
.await .await
.context("Unable to create unique instance of Count; possibly one already exists")?; .context("Unable to create unique instance of Count; possibly one already exists")?;
...@@ -141,7 +142,8 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -141,7 +142,8 @@ async fn app(runtime: Runtime) -> Result<()> {
let service_subject = target_endpoint.subject(); let service_subject = target_endpoint.subject();
tracing::info!("Scraping endpoint {service_path} for stats"); tracing::info!("Scraping endpoint {service_path} for stats");
let token = drt.primary_lease().child_token(); // Safety: DistributedRuntime::from_settings ensures this is Some
let token = drt.primary_lease().unwrap().child_token();
let event_name = format!("l2c.{}.{}", config.component_name, config.endpoint_name); let event_name = format!("l2c.{}.{}", config.component_name, config.endpoint_name);
// Initialize Prometheus metrics with the selected mode // Initialize Prometheus metrics with the selected mode
......
...@@ -64,6 +64,8 @@ class NixlMetadataStore: ...@@ -64,6 +64,8 @@ class NixlMetadataStore:
self._cached: dict[str, NixlMetadata] = {} self._cached: dict[str, NixlMetadata] = {}
self._client = runtime.etcd_client() self._client = runtime.etcd_client()
if self._client is None:
raise Exception("Cannot be used with static workers")
self._key_prefix = f"{self._namespace}/{NixlMetadataStore.NIXL_METADATA_KEY}" self._key_prefix = f"{self._namespace}/{NixlMetadataStore.NIXL_METADATA_KEY}"
async def put(self, engine_id, metadata: NixlMetadata): async def put(self, engine_id, metadata: NixlMetadata):
......
...@@ -43,8 +43,6 @@ pub async fn run( ...@@ -43,8 +43,6 @@ pub async fn run(
let cancel_token = runtime.primary_token().clone(); let cancel_token = runtime.primary_token().clone();
let endpoint_id: Endpoint = path.parse()?; let endpoint_id: Endpoint = path.parse()?;
let etcd_client = distributed.etcd_client();
let (ingress, service_name) = match engine_config { let (ingress, service_name) = match engine_config {
EngineConfig::StaticFull { EngineConfig::StaticFull {
service_name, service_name,
...@@ -95,15 +93,18 @@ pub async fn run( ...@@ -95,15 +93,18 @@ pub async fn run(
.create() .create()
.await? .await?
.endpoint(endpoint_id.name); .endpoint(endpoint_id.name);
let network_name = endpoint.subject();
tracing::debug!("Registering with etcd as {network_name}"); if let Some(etcd_client) = distributed.etcd_client() {
etcd_client let network_name = endpoint.subject();
.kv_create( tracing::debug!("Registering with etcd as {network_name}");
network_name.clone(), etcd_client
serde_json::to_vec_pretty(&model_registration)?, .kv_create(
Some(etcd_client.lease_id()), network_name.clone(),
) serde_json::to_vec_pretty(&model_registration)?,
.await?; Some(etcd_client.lease_id()),
)
.await?;
}
let rt_fut = endpoint.endpoint_builder().handler(ingress).start(); let rt_fut = endpoint.endpoint_builder().handler(ingress).start();
tokio::select! { tokio::select! {
......
...@@ -47,26 +47,33 @@ pub async fn run( ...@@ -47,26 +47,33 @@ pub async fn run(
.build()?; .build()?;
match engine_config { match engine_config {
EngineConfig::Dynamic(endpoint) => { EngineConfig::Dynamic(endpoint) => {
// This will attempt to connect to NATS and etcd
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
match distributed_runtime.etcd_client() {
Some(etcd_client) => {
// This will attempt to connect to NATS and etcd
let component = distributed_runtime let component = distributed_runtime
.namespace(endpoint.namespace)? .namespace(endpoint.namespace)?
.component(endpoint.component)?; .component(endpoint.component)?;
let network_prefix = component.service_name(); let network_prefix = component.service_name();
// Listen for models registering themselves in etcd, add them to HTTP service // Listen for models registering themselves in etcd, add them to HTTP service
let state = Arc::new(discovery::ModelWatchState { let state = Arc::new(discovery::ModelWatchState {
prefix: network_prefix.clone(), prefix: network_prefix.clone(),
model_type: ModelType::Chat, model_type: ModelType::Chat,
manager: http_service.model_manager().clone(), manager: http_service.model_manager().clone(),
drt: distributed_runtime.clone(), drt: distributed_runtime.clone(),
}); });
tracing::info!("Waiting for remote model at {network_prefix}"); tracing::info!("Waiting for remote model at {network_prefix}");
let etcd_client = distributed_runtime.etcd_client(); let models_watcher =
let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?; etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve(); let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let _watcher_task = tokio::spawn(discovery::model_watcher(state, receiver)); let _watcher_task = tokio::spawn(discovery::model_watcher(state, receiver));
}
None => {
// Static endpoints don't need discovery
}
}
} }
EngineConfig::StaticFull { EngineConfig::StaticFull {
service_name, service_name,
......
...@@ -269,7 +269,9 @@ async fn add_model( ...@@ -269,7 +269,9 @@ async fn add_model(
model_type.as_str(), model_type.as_str(),
model_name model_name
); );
let etcd_client = distributed.etcd_client(); let etcd_client = distributed
.etcd_client()
.expect("unreachable: llmctl is only useful with dynamic workers");
// check if model already exists // check if model already exists
let kvs = etcd_client.kv_get_prefix(&path).await?; let kvs = etcd_client.kv_get_prefix(&path).await?;
...@@ -321,7 +323,9 @@ async fn list_single_model( ...@@ -321,7 +323,9 @@ async fn list_single_model(
); );
let mut models = Vec::new(); let mut models = Vec::new();
let etcd_client = distributed.etcd_client(); let etcd_client = distributed
.etcd_client()
.expect("llmctl is only useful for dynamic workers");
let kvs = etcd_client.kv_get_prefix(&path).await?; let kvs = etcd_client.kv_get_prefix(&path).await?;
for kv in kvs { for kv in kvs {
...@@ -364,7 +368,9 @@ async fn list_models( ...@@ -364,7 +368,9 @@ async fn list_models(
for mt in model_types { for mt in model_types {
let prefix = format!("{}/models/{}/", component.etcd_path(), mt.as_str(),); let prefix = format!("{}/models/{}/", component.etcd_path(), mt.as_str(),);
let etcd_client = distributed.etcd_client(); let etcd_client = distributed
.etcd_client()
.expect("llmctl is only useful with dynamic workers");
let kvs = etcd_client.kv_get_prefix(&prefix).await?; let kvs = etcd_client.kv_get_prefix(&prefix).await?;
for kv in kvs { for kv in kvs {
...@@ -424,7 +430,11 @@ async fn remove_model( ...@@ -424,7 +430,11 @@ async fn remove_model(
log::debug!("deleting key: {}", prefix); log::debug!("deleting key: {}", prefix);
// get the kvs from etcd // get the kvs from etcd
let mut kv_client = distributed.etcd_client().etcd_client().kv_client(); let mut kv_client = distributed
.etcd_client()
.expect("llmctl is only useful with dynamic workers")
.etcd_client()
.kv_client();
match kv_client.delete(prefix.as_bytes(), None).await { match kv_client.delete(prefix.as_bytes(), None).await {
Ok(_response) => { Ok(_response) => {
println!( println!(
......
...@@ -20,7 +20,7 @@ import uvloop ...@@ -20,7 +20,7 @@ import uvloop
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
@dynamo_worker() @dynamo_worker(static=True)
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
await init(runtime, "dynamo") await init(runtime, "dynamo")
......
...@@ -29,10 +29,11 @@ def random_string(length=10): ...@@ -29,10 +29,11 @@ def random_string(length=10):
return "".join(random.choices(chars, k=length)) return "".join(random.choices(chars, k=length))
@dynamo_worker() @dynamo_worker(static=True)
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
ns = random_string() ns = random_string()
task = asyncio.create_task(server_init(runtime, ns)) task = asyncio.create_task(server_init(runtime, ns))
await asyncio.sleep(0.1) # let the server start
await client_init(runtime, ns) await client_init(runtime, ns)
runtime.shutdown() runtime.shutdown()
await task await task
......
...@@ -31,7 +31,7 @@ class RequestHandler: ...@@ -31,7 +31,7 @@ class RequestHandler:
yield char yield char
@dynamo_worker() @dynamo_worker(static=True)
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
await init(runtime, "dynamo") await init(runtime, "dynamo")
......
...@@ -144,7 +144,7 @@ struct Client { ...@@ -144,7 +144,7 @@ struct Client {
#[pymethods] #[pymethods]
impl DistributedRuntime { impl DistributedRuntime {
#[new] #[new]
fn new(event_loop: PyObject) -> PyResult<Self> { fn new(event_loop: PyObject, is_static: bool) -> PyResult<Self> {
let worker = rs::Worker::from_settings().map_err(to_pyerr)?; let worker = rs::Worker::from_settings().map_err(to_pyerr)?;
INIT.get_or_try_init(|| { INIT.get_or_try_init(|| {
let primary = worker.tokio_runtime()?; let primary = worker.tokio_runtime()?;
...@@ -156,11 +156,17 @@ impl DistributedRuntime { ...@@ -156,11 +156,17 @@ impl DistributedRuntime {
let runtime = worker.runtime().clone(); let runtime = worker.runtime().clone();
let inner = worker let inner =
.runtime() if is_static {
.secondary() runtime.secondary().block_on(
.block_on(rs::DistributedRuntime::from_settings(runtime)) rs::DistributedRuntime::from_settings_without_discovery(runtime),
.map_err(to_pyerr)?; )
} else {
runtime
.secondary()
.block_on(rs::DistributedRuntime::from_settings(runtime))
};
let inner = inner.map_err(to_pyerr)?;
Ok(DistributedRuntime { inner, event_loop }) Ok(DistributedRuntime { inner, event_loop })
} }
...@@ -172,10 +178,11 @@ impl DistributedRuntime { ...@@ -172,10 +178,11 @@ impl DistributedRuntime {
}) })
} }
fn etcd_client(&self) -> PyResult<EtcdClient> { fn etcd_client(&self) -> PyResult<Option<EtcdClient>> {
Ok(EtcdClient { match self.inner.etcd_client().clone() {
inner: self.inner.etcd_client().clone(), Some(etcd_client) => Ok(Some(EtcdClient { inner: etcd_client })),
}) None => Ok(None),
}
} }
fn primary_token(&self) -> CancellationToken { fn primary_token(&self) -> CancellationToken {
...@@ -262,7 +269,11 @@ impl Endpoint { ...@@ -262,7 +269,11 @@ impl Endpoint {
} }
fn lease_id(&self) -> i64 { fn lease_id(&self) -> i64 {
self.inner.drt().primary_lease().id() self.inner
.drt()
.primary_lease()
.map(|l| l.id())
.unwrap_or(0)
} }
} }
...@@ -348,7 +359,7 @@ impl EtcdClient { ...@@ -348,7 +359,7 @@ impl EtcdClient {
impl Client { impl Client {
/// Get list of current endpoints /// Get list of current endpoints
fn endpoint_ids(&self) -> Vec<i64> { fn endpoint_ids(&self) -> Vec<i64> {
self.inner.endpoint_ids().borrow().clone() self.inner.endpoint_ids()
} }
fn wait_for_endpoints<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> { fn wait_for_endpoints<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
...@@ -366,7 +377,11 @@ impl Client { ...@@ -366,7 +377,11 @@ impl Client {
request: PyObject, request: PyObject,
annotated: Option<bool>, annotated: Option<bool>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
self.random(py, request, annotated) if self.inner.is_static() {
self.r#static(py, request, annotated)
} else {
self.random(py, request, annotated)
}
} }
/// Send a request to the next endpoint in a round-robin fashion. /// Send a request to the next endpoint in a round-robin fashion.
...@@ -446,6 +461,32 @@ impl Client { ...@@ -446,6 +461,32 @@ impl Client {
}) })
}) })
} }
/// Directly send a request to a pre-defined static worker
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))]
fn r#static<'p>(
&self,
py: Python<'p>,
request: PyObject,
annotated: Option<bool>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
let annotated = annotated.unwrap_or(false);
let (tx, rx) = tokio::sync::mpsc::channel(32);
let client = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client.r#static(request.into()).await.map_err(to_pyerr)?;
tokio::spawn(process_stream(stream, tx));
Ok(AsyncResponseStream {
rx: Arc::new(Mutex::new(rx)),
annotated,
})
})
}
} }
async fn process_stream( async fn process_stream(
......
...@@ -43,9 +43,9 @@ class DistributedRuntime: ...@@ -43,9 +43,9 @@ class DistributedRuntime:
""" """
... ...
def etcd_client(self) -> EtcdClient: def etcd_client(self) -> Optional[EtcdClient]:
""" """
Get the `EtcdClient` object Get the `EtcdClient` object. Not available for static workers.
""" """
... ...
......
...@@ -30,12 +30,12 @@ from dynamo._core import ModelDeploymentCard as ModelDeploymentCard ...@@ -30,12 +30,12 @@ from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor
def dynamo_worker(): def dynamo_worker(static=False):
def decorator(func): def decorator(func):
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop) runtime = DistributedRuntime(loop, static)
await func(runtime, *args, **kwargs) await func(runtime, *args, **kwargs)
......
...@@ -24,7 +24,7 @@ from dynamo._core import DistributedRuntime ...@@ -24,7 +24,7 @@ from dynamo._core import DistributedRuntime
async def test_simple_put_get(): async def test_simple_put_get():
# Initialize runtime # Initialize runtime
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
runtime = DistributedRuntime(loop) runtime = DistributedRuntime(loop, False)
# Get etcd client # Get etcd client
etcd = runtime.etcd_client() etcd = runtime.etcd_client()
......
...@@ -56,7 +56,7 @@ def setup_and_teardown(): ...@@ -56,7 +56,7 @@ def setup_and_teardown():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
async def distributed_runtime(): async def distributed_runtime():
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
return DistributedRuntime(loop) return DistributedRuntime(loop, False)
# TODO Figure out how to test with different kv_block_size # TODO Figure out how to test with different kv_block_size
......
...@@ -38,11 +38,14 @@ impl DisaggRouterConf { ...@@ -38,11 +38,14 @@ impl DisaggRouterConf {
pub async fn from_etcd_with_watcher( pub async fn from_etcd_with_watcher(
drt: Arc<DistributedRuntime>, drt: Arc<DistributedRuntime>,
model_name: &str, model_name: &str,
) -> Result<(Self, watch::Receiver<Self>), Box<dyn std::error::Error>> { ) -> anyhow::Result<(Self, watch::Receiver<Self>)> {
let etcd_key = format!("public/components/disagg_router/models/chat/{}", model_name); let etcd_key = format!("public/components/disagg_router/models/chat/{}", model_name);
// Get the initial value if it exists // Get the initial value if it exists
let initial_config = match drt.etcd_client().kv_get_prefix(&etcd_key).await { let Some(etcd_client) = drt.etcd_client() else {
anyhow::bail!("Static components don't have an etcd client");
};
let initial_config = match etcd_client.kv_get_prefix(&etcd_key).await {
Ok(kvs) => { Ok(kvs) => {
if let Some(kv) = kvs.first() { if let Some(kv) = kvs.first() {
match serde_json::from_slice::<DisaggRouterConf>(kv.value()) { match serde_json::from_slice::<DisaggRouterConf>(kv.value()) {
...@@ -81,7 +84,7 @@ impl DisaggRouterConf { ...@@ -81,7 +84,7 @@ impl DisaggRouterConf {
let (watch_tx, watch_rx) = watch::channel(initial_config.clone()); let (watch_tx, watch_rx) = watch::channel(initial_config.clone());
// Set up the watcher after getting the initial value // Set up the watcher after getting the initial value
let prefix_watcher = drt.etcd_client().kv_get_and_watch_prefix(&etcd_key).await?; let prefix_watcher = etcd_client.kv_get_and_watch_prefix(&etcd_key).await?;
let (key, _watcher, mut kv_event_rx) = prefix_watcher.dissolve(); let (key, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
// Spawn background task to watch for config changes // Spawn background task to watch for config changes
...@@ -160,7 +163,7 @@ impl DisaggregatedRouter { ...@@ -160,7 +163,7 @@ impl DisaggregatedRouter {
drt: Arc<DistributedRuntime>, drt: Arc<DistributedRuntime>,
model_name: String, model_name: String,
default_max_local_prefill_length: i32, default_max_local_prefill_length: i32,
) -> Result<Self, Box<dyn std::error::Error>> { ) -> anyhow::Result<Self> {
let (mut config, watcher) = let (mut config, watcher) =
DisaggRouterConf::from_etcd_with_watcher(drt, &model_name).await?; DisaggRouterConf::from_etcd_with_watcher(drt, &model_name).await?;
......
...@@ -74,7 +74,11 @@ impl KvRouter { ...@@ -74,7 +74,11 @@ impl KvRouter {
block_size: usize, block_size: usize,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>, selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
) -> Result<Arc<Self>> { ) -> Result<Arc<Self>> {
let cancellation_token = component.drt().primary_lease().primary_token(); let cancellation_token = component
.drt()
.primary_lease()
.expect("Cannot KV route static workers")
.primary_token();
let metrics_aggregator = let metrics_aggregator =
KvMetricsAggregator::new(component.clone(), cancellation_token.clone()).await; KvMetricsAggregator::new(component.clone(), cancellation_token.clone()).await;
......
...@@ -118,6 +118,10 @@ pub struct Component { ...@@ -118,6 +118,10 @@ pub struct Component {
/// Namespace /// Namespace
#[builder(setter(into))] #[builder(setter(into))]
namespace: Namespace, namespace: Namespace,
// A static component's endpoints cannot be discovered via etcd, they are
// fixed at startup time.
is_static: bool,
} }
impl std::fmt::Display for Component { impl std::fmt::Display for Component {
...@@ -144,7 +148,8 @@ impl Component { ...@@ -144,7 +148,8 @@ impl Component {
} }
pub fn service_name(&self) -> String { pub fn service_name(&self) -> String {
Slug::from_string(format!("{}|{}", self.namespace.name(), self.name)).to_string() let service_name = format!("{}_{}", self.namespace.name(), self.name);
Slug::slugify_unique(&service_name).to_string()
} }
pub fn path(&self) -> String { pub fn path(&self) -> String {
...@@ -159,6 +164,7 @@ impl Component { ...@@ -159,6 +164,7 @@ impl Component {
Endpoint { Endpoint {
component: self.clone(), component: self.clone(),
name: endpoint.into(), name: endpoint.into(),
is_static: self.is_static,
} }
} }
...@@ -204,6 +210,8 @@ pub struct Endpoint { ...@@ -204,6 +210,8 @@ pub struct Endpoint {
// todo - restrict alphabet // todo - restrict alphabet
/// Endpoint name /// Endpoint name
name: String, name: String,
is_static: bool,
} }
impl DistributedRuntimeProvider for Endpoint { impl DistributedRuntimeProvider for Endpoint {
...@@ -236,11 +244,19 @@ impl Endpoint { ...@@ -236,11 +244,19 @@ impl Endpoint {
} }
pub fn etcd_path_with_id(&self, lease_id: i64) -> String { pub fn etcd_path_with_id(&self, lease_id: i64) -> String {
format!("{}:{:x}", self.etcd_path(), lease_id) if self.is_static {
self.etcd_path()
} else {
format!("{}:{:x}", self.etcd_path(), lease_id)
}
} }
pub fn name_with_id(&self, lease_id: i64) -> String { pub fn name_with_id(&self, lease_id: i64) -> String {
format!("{}-{:x}", self.name, lease_id) if self.is_static {
self.name.clone()
} else {
format!("{}-{:x}", self.name, lease_id)
}
} }
pub fn subject(&self) -> String { pub fn subject(&self) -> String {
...@@ -261,7 +277,11 @@ impl Endpoint { ...@@ -261,7 +277,11 @@ impl Endpoint {
Req: Serialize + Send + Sync + 'static, Req: Serialize + Send + Sync + 'static,
Resp: for<'de> Deserialize<'de> + Send + Sync + 'static, Resp: for<'de> Deserialize<'de> + Send + Sync + 'static,
{ {
client::Client::new(self.clone()).await if self.is_static {
client::Client::new_static(self.clone()).await
} else {
client::Client::new_dynamic(self.clone()).await
}
} }
pub fn endpoint_builder(&self) -> endpoint::EndpointConfigBuilder { pub fn endpoint_builder(&self) -> endpoint::EndpointConfigBuilder {
...@@ -279,6 +299,8 @@ pub struct Namespace { ...@@ -279,6 +299,8 @@ pub struct Namespace {
#[validate()] #[validate()]
name: String, name: String,
is_static: bool,
} }
impl DistributedRuntimeProvider for Namespace { impl DistributedRuntimeProvider for Namespace {
...@@ -300,18 +322,20 @@ impl std::fmt::Display for Namespace { ...@@ -300,18 +322,20 @@ impl std::fmt::Display for Namespace {
} }
impl Namespace { impl Namespace {
pub(crate) fn new(runtime: DistributedRuntime, name: String) -> Result<Self> { pub(crate) fn new(runtime: DistributedRuntime, name: String, is_static: bool) -> Result<Self> {
Ok(NamespaceBuilder::default() Ok(NamespaceBuilder::default()
.runtime(runtime) .runtime(runtime)
.name(name) .name(name)
.is_static(is_static)
.build()?) .build()?)
} }
/// Create a [`Component`] in the namespace /// Create a [`Component`] in the namespace whose endpoints can be discovered with etcd
pub fn component(&self, name: impl Into<String>) -> Result<Component> { pub fn component(&self, name: impl Into<String>) -> Result<Component> {
Ok(ComponentBuilder::from_runtime(self.runtime.clone()) Ok(ComponentBuilder::from_runtime(self.runtime.clone())
.name(name) .name(name)
.namespace(self.clone()) .namespace(self.clone())
.is_static(self.is_static)
.build()?) .build()?)
} }
......
...@@ -52,8 +52,14 @@ enum EndpointEvent { ...@@ -52,8 +52,14 @@ enum EndpointEvent {
pub struct Client<T: Data, U: Data> { pub struct Client<T: Data, U: Data> {
endpoint: Endpoint, endpoint: Endpoint,
router: PushRouter<T, U>, router: PushRouter<T, U>,
watch_rx: tokio::sync::watch::Receiver<Vec<i64>>,
counter: Arc<AtomicU64>, counter: Arc<AtomicU64>,
endpoints: EndpointSource,
}
#[derive(Clone, Debug)]
enum EndpointSource {
Static,
Dynamic(tokio::sync::watch::Receiver<Vec<i64>>),
} }
impl<T, U> Client<T, U> impl<T, U> Client<T, U>
...@@ -61,17 +67,23 @@ where ...@@ -61,17 +67,23 @@ where
T: Data + Serialize, T: Data + Serialize,
U: Data + for<'de> Deserialize<'de>, U: Data + for<'de> Deserialize<'de>,
{ {
pub(crate) async fn new(endpoint: Endpoint) -> Result<Self> { // Client will only talk to a single static endpoint
let router = AddressedPushRouter::new( pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> {
endpoint.component.drt.nats_client.client().clone(), Ok(Client {
endpoint.component.drt.tcp_server().await?, router: router(&endpoint).await?,
)?; endpoint,
counter: Arc::new(AtomicU64::new(0)),
endpoints: EndpointSource::Static,
})
}
// Client that auto-discovers endpoints using etcd
pub(crate) async fn new_dynamic(endpoint: Endpoint) -> Result<Self> {
// create live endpoint watcher // create live endpoint watcher
let prefix_watcher = endpoint let Some(etcd_client) = &endpoint.component.drt.etcd_client else {
.component anyhow::bail!("Attempt to create a dynamic client on a static endpoint");
.drt };
.etcd_client let prefix_watcher = etcd_client
.kv_get_and_watch_prefix(endpoint.etcd_path()) .kv_get_and_watch_prefix(endpoint.etcd_path())
.await?; .await?;
...@@ -141,10 +153,10 @@ where ...@@ -141,10 +153,10 @@ where
}); });
Ok(Client { Ok(Client {
router: router(&endpoint).await?,
endpoint, endpoint,
router,
watch_rx,
counter: Arc::new(AtomicU64::new(0)), counter: Arc::new(AtomicU64::new(0)),
endpoints: EndpointSource::Dynamic(watch_rx),
}) })
} }
...@@ -158,31 +170,39 @@ where ...@@ -158,31 +170,39 @@ where
self.endpoint.etcd_path() self.endpoint.etcd_path()
} }
pub fn endpoint_ids(&self) -> &tokio::sync::watch::Receiver<Vec<i64>> { pub fn endpoint_ids(&self) -> Vec<i64> {
&self.watch_rx match &self.endpoints {
EndpointSource::Static => vec![0],
EndpointSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
}
} }
/// Wait for at least one [`Endpoint`] to be available /// Wait for at least one [`Endpoint`] to be available
pub async fn wait_for_endpoints(&self) -> Result<()> { pub async fn wait_for_endpoints(&self) -> Result<()> {
let mut rx = self.watch_rx.clone(); if let EndpointSource::Dynamic(mut rx) = self.endpoints.clone() {
// wait for there to be 1 or more endpoints // wait for there to be 1 or more endpoints
loop { loop {
if rx.borrow_and_update().is_empty() { if rx.borrow_and_update().is_empty() {
rx.changed().await?; rx.changed().await?;
} else { } else {
break; break;
}
} }
} }
Ok(()) Ok(())
} }
/// Is this component known at startup and not discovered via etcd?
pub fn is_static(&self) -> bool {
matches!(self.endpoints, EndpointSource::Static)
}
/// Issue a request to the next available endpoint in a round-robin fashion /// Issue a request to the next available endpoint in a round-robin fashion
pub async fn round_robin(&self, request: SingleIn<T>) -> Result<ManyOut<U>> { pub async fn round_robin(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
let counter = self.counter.fetch_add(1, Ordering::Relaxed); let counter = self.counter.fetch_add(1, Ordering::Relaxed);
let endpoint_id = { let endpoint_id = {
let endpoints = self.watch_rx.borrow(); let endpoints = self.endpoint_ids();
let count = endpoints.len(); let count = endpoints.len();
if count == 0 { if count == 0 {
return Err(error!( return Err(error!(
...@@ -203,7 +223,7 @@ where ...@@ -203,7 +223,7 @@ where
/// Issue a request to a random endpoint /// Issue a request to a random endpoint
pub async fn random(&self, request: SingleIn<T>) -> Result<ManyOut<U>> { pub async fn random(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
let endpoint_id = { let endpoint_id = {
let endpoints = self.watch_rx.borrow(); let endpoints = self.endpoint_ids();
let count = endpoints.len(); let count = endpoints.len();
if count == 0 { if count == 0 {
return Err(error!( return Err(error!(
...@@ -225,7 +245,7 @@ where ...@@ -225,7 +245,7 @@ where
/// Issue a request to a specific endpoint /// Issue a request to a specific endpoint
pub async fn direct(&self, request: SingleIn<T>, endpoint_id: i64) -> Result<ManyOut<U>> { pub async fn direct(&self, request: SingleIn<T>, endpoint_id: i64) -> Result<ManyOut<U>> {
let found = { let found = {
let endpoints = self.watch_rx.borrow(); let endpoints = self.endpoint_ids();
endpoints.contains(&endpoint_id) endpoints.contains(&endpoint_id)
}; };
...@@ -242,6 +262,21 @@ where ...@@ -242,6 +262,21 @@ where
self.router.generate(request).await self.router.generate(request).await
} }
pub async fn r#static(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
let subject = self.endpoint.subject();
tracing::debug!("static got subject: {subject}");
let request = request.map(|req| AddressedRequest::new(req, subject));
tracing::debug!("router generate");
self.router.generate(request).await
}
}
async fn router(endpoint: &Endpoint) -> Result<Arc<AddressedPushRouter>> {
AddressedPushRouter::new(
endpoint.component.drt.nats_client.client().clone(),
endpoint.component.drt.tcp_server().await?,
)
} }
#[async_trait] #[async_trait]
...@@ -251,6 +286,10 @@ where ...@@ -251,6 +286,10 @@ where
U: Data + for<'de> Deserialize<'de>, U: Data + for<'de> Deserialize<'de>,
{ {
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> { async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
self.random(request).await tracing::debug!("Client::generate: {:?}", self.endpoints);
match &self.endpoints {
EndpointSource::Static => self.r#static(request).await,
EndpointSource::Dynamic(_) => self.random(request).await,
}
} }
} }
...@@ -56,11 +56,12 @@ impl EndpointConfigBuilder { ...@@ -56,11 +56,12 @@ impl EndpointConfigBuilder {
pub async fn start(self) -> Result<()> { pub async fn start(self) -> Result<()> {
let (endpoint, lease, handler, stats_handler) = self.build_internal()?.dissolve(); let (endpoint, lease, handler, stats_handler) = self.build_internal()?.dissolve();
let lease = lease.unwrap_or(endpoint.drt().primary_lease()); let lease = lease.or(endpoint.drt().primary_lease());
let lease_id = lease.as_ref().map(|l| l.id()).unwrap_or(0);
tracing::debug!( tracing::debug!(
"Starting endpoint: {}", "Starting endpoint: {}",
endpoint.etcd_path_with_id(lease.id()) endpoint.etcd_path_with_id(lease_id)
); );
let service_name = endpoint.component.service_name(); let service_name = endpoint.component.service_name();
...@@ -89,16 +90,18 @@ impl EndpointConfigBuilder { ...@@ -89,16 +90,18 @@ impl EndpointConfigBuilder {
handler_map handler_map
.lock() .lock()
.unwrap() .unwrap()
.insert(endpoint.subject_to(lease.id()), stats_handler); .insert(endpoint.subject_to(lease_id), stats_handler);
} }
// creates an endpoint for the service // creates an endpoint for the service
let service_endpoint = group let service_endpoint = group
.endpoint(&endpoint.name_with_id(lease.id())) .endpoint(&endpoint.name_with_id(lease_id))
.await .await
.map_err(|e| anyhow::anyhow!("Failed to start endpoint: {e}"))?; .map_err(|e| anyhow::anyhow!("Failed to start endpoint: {e}"))?;
let cancel_token = lease.child_token(); let cancel_token = lease
.map(|l| l.child_token())
.unwrap_or_else(|| endpoint.drt().child_token());
let push_endpoint = PushEndpoint::builder() let push_endpoint = PushEndpoint::builder()
.service_handler(handler) .service_handler(handler)
...@@ -116,28 +119,22 @@ impl EndpointConfigBuilder { ...@@ -116,28 +119,22 @@ impl EndpointConfigBuilder {
component: endpoint.component.name.clone(), component: endpoint.component.name.clone(),
endpoint: endpoint.name.clone(), endpoint: endpoint.name.clone(),
namespace: endpoint.component.namespace.name.clone(), namespace: endpoint.component.namespace.name.clone(),
lease_id: lease.id(), lease_id,
transport: TransportType::NatsTcp(endpoint.subject_to(lease.id())), transport: TransportType::NatsTcp(endpoint.subject_to(lease_id)),
}; };
let info = serde_json::to_vec_pretty(&info)?; let info = serde_json::to_vec_pretty(&info)?;
if let Err(e) = endpoint if let Some(etcd_client) = &endpoint.component.drt.etcd_client {
.component if let Err(e) = etcd_client
.drt .kv_create(endpoint.etcd_path_with_id(lease_id), info, Some(lease_id))
.etcd_client .await
.kv_create( {
endpoint.etcd_path_with_id(lease.id()), tracing::error!("Failed to register discoverable service: {:?}", e);
info, cancel_token.cancel();
Some(lease.id()), return Err(error!("Failed to register discoverable service"));
) }
.await
{
tracing::error!("Failed to register discoverable service: {:?}", e);
cancel_token.cancel();
return Err(error!("Failed to register discoverable service"));
} }
task.await??; task.await??;
Ok(()) Ok(())
......
...@@ -69,7 +69,7 @@ impl ServiceConfigBuilder { ...@@ -69,7 +69,7 @@ impl ServiceConfigBuilder {
let builder = component.drt.nats_client.client().service_builder(); let builder = component.drt.nats_client.client().service_builder();
tracing::debug!("Starting service: {}", service_name); tracing::debug!("Starting service: {}", service_name);
let service = builder let service_builder = builder
.description(description) .description(description)
.stats_handler(move |name, stats| { .stats_handler(move |name, stats| {
log::trace!("stats_handler: {name}, {stats:?}"); log::trace!("stats_handler: {name}, {stats:?}");
...@@ -78,11 +78,15 @@ impl ServiceConfigBuilder { ...@@ -78,11 +78,15 @@ impl ServiceConfigBuilder {
Some(handler) => handler(stats), Some(handler) => handler(stats),
None => serde_json::Value::Null, None => serde_json::Value::Null,
} }
}) });
tracing::debug!("Got builder");
let service = service_builder
.start(service_name.clone(), version) .start(service_name.clone(), version)
.await .await
.map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?; .map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;
tracing::debug!("Service started TEMP");
// new copy of service_name as the previous one is moved into the task above // new copy of service_name as the previous one is moved into the task above
let service_name = component.service_name(); let service_name = component.service_name();
...@@ -97,6 +101,7 @@ impl ServiceConfigBuilder { ...@@ -97,6 +101,7 @@ impl ServiceConfigBuilder {
// drop the guard to unlock the mutex // drop the guard to unlock the mutex
drop(guard); drop(guard);
tracing::debug!("create done");
Ok(component) Ok(component)
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment