Unverified Commit 0ce3461a authored by Tzu-Ling Kan's avatar Tzu-Ling Kan Committed by GitHub
Browse files

feat: Add runtime.endpoint() method to eliminate namespace chaining (#6386)


Signed-off-by: default avatartzulingk@nvidia.com <tzulingk@nvidia.com>
parent 6f4b33f7
......@@ -41,9 +41,7 @@ async def worker(runtime: DistributedRuntime):
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
"""
component = runtime.namespace("dynamo").component("backend")
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint("dynamo.backend.generate")
await endpoint.serve_endpoint(RequestHandler().generate)
......
......@@ -150,7 +150,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<DistributedRuntime>()?;
m.add_class::<CancellationToken>()?;
m.add_class::<Namespace>()?;
m.add_class::<Component>()?;
m.add_class::<Endpoint>()?;
m.add_class::<ModelCardInstanceId>()?;
......@@ -461,13 +460,6 @@ struct CancellationToken {
inner: rs::CancellationToken,
}
#[pyclass]
#[derive(Clone)]
struct Namespace {
inner: rs::component::Namespace,
event_loop: PyObject,
}
#[pyclass]
#[derive(Clone)]
struct Component {
......@@ -648,9 +640,34 @@ impl DistributedRuntime {
})
}
fn namespace(&self, name: String) -> PyResult<Namespace> {
Ok(Namespace {
inner: self.inner.namespace(name).map_err(to_pyerr)?,
/// Get an endpoint directly by path (e.g., "namespace.component.endpoint" or "dyn://...").
fn endpoint(&self, path: String) -> PyResult<Endpoint> {
let trimmed_path = path.trim_start_matches("dyn://");
let parts: Vec<&str> = trimmed_path.split('.').collect();
if parts.len() != 3 {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Invalid endpoint path '{}'. Expected format: 'namespace.component.endpoint' or 'dyn://namespace.component.endpoint'",
path
)));
}
let namespace_name = parts[0];
let component_name = parts[1];
let endpoint_name = parts[2];
// Get endpoint using existing chain
let namespace = self
.inner
.namespace(namespace_name.to_string())
.map_err(to_pyerr)?;
let component = namespace
.component(component_name.to_string())
.map_err(to_pyerr)?;
let endpoint = component.endpoint(endpoint_name.to_string());
Ok(Endpoint {
inner: endpoint,
event_loop: self.event_loop.clone(),
})
}
......@@ -910,16 +927,16 @@ impl Endpoint {
Ok(())
})
}
}
#[pymethods]
impl Namespace {
fn component(&self, name: String) -> PyResult<Component> {
let inner = self.inner.component(name).map_err(to_pyerr)?;
Ok(Component {
inner,
/// Get the parent Component.
///
/// Note: To avoid duplicate metrics registries, reuse the returned Component for
/// multiple endpoints: `component.endpoint("ep1")`, `component.endpoint("ep2")`.
fn component(&self) -> Component {
Component {
inner: self.inner.component().clone(),
event_loop: self.event_loop.clone(),
})
}
}
}
......
......@@ -58,9 +58,24 @@ class DistributedRuntime:
"""
...
def namespace(self, name: str) -> Namespace:
def endpoint(self, path: str) -> Endpoint:
"""
Create a `Namespace` object
Get an endpoint directly by path.
Args:
path: Endpoint path in format 'namespace.component.endpoint'
or 'dyn://namespace.component.endpoint'
Returns:
Endpoint: The requested endpoint
Raises:
ValueError: If path format is invalid (not 3 parts separated by dots)
Exception: If namespace or component creation fails
Example:
endpoint = runtime.endpoint("demo.backend.generate")
endpoint = runtime.endpoint("dyn://demo.backend.generate")
"""
...
......@@ -116,19 +131,6 @@ class CancellationToken:
...
class Namespace:
"""
A namespace is a collection of components
"""
...
def component(self, name: str) -> Component:
"""
Create a `Component` object
"""
...
class Component:
"""
A component is a collection of endpoints
......@@ -208,6 +210,24 @@ class Endpoint:
"""
...
def component(self) -> Component:
"""
Get the parent Component that this endpoint belongs to.
Returns:
Component: The parent component
Note:
To avoid duplicate metrics registries, reuse the returned Component for
multiple endpoints: component.endpoint("ep1"), component.endpoint("ep2")
Example:
endpoint = runtime.endpoint("demo.backend.generate")
component = endpoint.component()
health_endpoint = component.endpoint("health") # Reuse component
"""
...
class Client:
"""
......
......@@ -15,7 +15,7 @@ from dynamo._core import Component as Component
from dynamo._core import Context as Context
from dynamo._core import DistributedRuntime as DistributedRuntime
from dynamo._core import Endpoint as Endpoint
from dynamo._core import Namespace as Namespace
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
def dynamo_worker(enable_nats: bool = True):
......
......@@ -128,8 +128,7 @@ async def server(runtime, namespace):
async def init_server():
"""Initialize the test server component and serve the generate endpoint"""
component = runtime.namespace(namespace).component("backend")
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint(f"{namespace}.backend.generate")
print("Started test server instance")
# Serve the endpoint - this will block until shutdown
......@@ -156,7 +155,7 @@ async def server(runtime, namespace):
async def client(runtime, namespace):
"""Create a client connected to the test server"""
# Create client
endpoint = runtime.namespace(namespace).component("backend").endpoint("generate")
endpoint = runtime.endpoint(f"{namespace}.backend.generate")
client = await endpoint.client()
await client.wait_for_instances()
......
......@@ -16,9 +16,7 @@ TEST_END_TO_END = os.environ.get("TEST_END_TO_END", 0)
@pytest.mark.asyncio
async def test_register(runtime: DistributedRuntime):
component = runtime.namespace("test").component("tensor")
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint("test.tensor.generate")
model_config = {
"name": "tensor",
......
......@@ -15,9 +15,7 @@ from dynamo.runtime import DistributedRuntime, dynamo_worker
@dynamo_worker()
async def echo_tensor_worker(runtime: DistributedRuntime):
component = runtime.namespace("tensor").component("echo")
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint("tensor.echo.generate")
triton_model_config = mc.ModelConfig()
triton_model_config.name = "echo"
......
......@@ -30,14 +30,10 @@ pytestmark = [
def mock_runtime():
"""Create a mock DistributedRuntime."""
runtime = MagicMock()
namespace_mock = MagicMock()
component_mock = MagicMock()
endpoint_mock = MagicMock()
client_mock = AsyncMock()
runtime.namespace.return_value = namespace_mock
namespace_mock.component.return_value = component_mock
component_mock.endpoint.return_value = endpoint_mock
runtime.endpoint.return_value = endpoint_mock
endpoint_mock.client = AsyncMock(return_value=client_mock)
client_mock.wait_for_instances = AsyncMock()
......@@ -82,21 +78,17 @@ async def test_send_scale_request_success(mock_runtime):
assert response.current_replicas["decode"] == 5
# Verify lazy init happened
assert client._client is not None
runtime.namespace.assert_called_once_with("central-ns")
runtime.endpoint.assert_called_once_with("central-ns.Planner.scale_request")
@pytest.mark.asyncio
async def test_send_scale_request_error():
"""Test scale request error handling."""
runtime = MagicMock()
namespace_mock = MagicMock()
component_mock = MagicMock()
endpoint_mock = MagicMock()
client_mock = AsyncMock()
runtime.namespace.return_value = namespace_mock
namespace_mock.component.return_value = component_mock
component_mock.endpoint.return_value = endpoint_mock
runtime.endpoint.return_value = endpoint_mock
endpoint_mock.client = AsyncMock(return_value=client_mock)
client_mock.wait_for_instances = AsyncMock()
......@@ -132,14 +124,10 @@ async def test_send_scale_request_error():
async def test_send_scale_request_no_response():
"""Test scale request when no response is received."""
runtime = MagicMock()
namespace_mock = MagicMock()
component_mock = MagicMock()
endpoint_mock = MagicMock()
client_mock = AsyncMock()
runtime.namespace.return_value = namespace_mock
namespace_mock.component.return_value = component_mock
component_mock.endpoint.return_value = endpoint_mock
runtime.endpoint.return_value = endpoint_mock
endpoint_mock.client = AsyncMock(return_value=client_mock)
client_mock.wait_for_instances = AsyncMock()
......
......@@ -1389,9 +1389,9 @@ def _test_router_indexers_sync(
# Create first runtime and endpoint for router 1
logger.info("Creating first KV router with its own runtime")
runtime1 = get_runtime(store_backend, request_plane)
namespace1 = runtime1.namespace(engine_workers.namespace)
component1 = namespace1.component(engine_workers.component_name)
endpoint1 = component1.endpoint("generate")
endpoint1 = runtime1.endpoint(
f"{engine_workers.namespace}.{engine_workers.component_name}.generate"
)
kv_router1 = KvRouter(
endpoint=endpoint1,
......@@ -1442,9 +1442,9 @@ def _test_router_indexers_sync(
# Create second runtime and endpoint for router 2
logger.info("Creating second KV router with its own runtime")
runtime2 = get_runtime(store_backend, request_plane)
namespace2 = runtime2.namespace(engine_workers.namespace)
component2 = namespace2.component(engine_workers.component_name)
endpoint2 = component2.endpoint("generate")
endpoint2 = runtime2.endpoint(
f"{engine_workers.namespace}.{engine_workers.component_name}.generate"
)
kv_router2 = KvRouter(
endpoint=endpoint2,
......
......@@ -507,9 +507,9 @@ def test_kv_router_bindings(
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane)
namespace = runtime.namespace(mockers.namespace)
component = namespace.component(mockers.component_name)
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint(
f"{mockers.namespace}.{mockers.component_name}.generate"
)
# Run Python router bindings test
_test_python_router_bindings(
......@@ -697,9 +697,7 @@ def test_router_decisions(
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane)
# Use the namespace from the mockers
namespace = runtime.namespace(mockers.namespace)
component = namespace.component("mocker")
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint(f"{mockers.namespace}.mocker.generate")
_test_router_decisions(
mockers,
......
......@@ -416,9 +416,7 @@ def test_router_decisions_sglang_multiple_workers(
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
runtime = get_runtime(request_plane=request_plane)
namespace = runtime.namespace(sglang_workers.namespace)
component = namespace.component("backend")
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint(f"{sglang_workers.namespace}.backend.generate")
_test_router_decisions(
sglang_workers,
......@@ -465,9 +463,7 @@ def test_router_decisions_sglang_dp(
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane)
# Use the namespace from the SGLang workers
namespace = runtime.namespace(sglang_workers.namespace)
component = namespace.component("backend") # endpoint is backend.generate
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint(f"{sglang_workers.namespace}.backend.generate")
_test_router_decisions(
sglang_workers, endpoint, MODEL_NAME, request, test_dp_rank=True
......
......@@ -408,9 +408,7 @@ def test_router_decisions_trtllm_attention_dp(
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane)
# Use the namespace from the vLLM workers
namespace = runtime.namespace(trtllm_workers.namespace)
component = namespace.component("tensorrt_llm")
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint(f"{trtllm_workers.namespace}.tensorrt_llm.generate")
_test_router_decisions(
trtllm_workers,
......@@ -457,9 +455,7 @@ def test_router_decisions_trtllm_multiple_workers(
logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}")
runtime = get_runtime(request_plane=request_plane)
namespace = runtime.namespace(trtllm_workers.namespace)
component = namespace.component("tensorrt_llm")
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint(f"{trtllm_workers.namespace}.tensorrt_llm.generate")
_test_router_decisions(
trtllm_workers,
......
......@@ -434,9 +434,7 @@ def test_router_decisions_vllm_multiple_workers(
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane)
namespace = runtime.namespace(vllm_workers.namespace)
component = namespace.component("backend")
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint(f"{vllm_workers.namespace}.backend.generate")
_test_router_decisions(
vllm_workers,
......@@ -483,9 +481,9 @@ def test_router_decisions_vllm_dp(
# Get runtime and create endpoint
runtime = get_runtime(request_plane=request_plane)
# Use the namespace from the vLLM workers
namespace = runtime.namespace(vllm_workers.namespace)
component = namespace.component("backend") # endpoint is backend.generate
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint(
f"{vllm_workers.namespace}.backend.generate"
) # endpoint is backend.generate
_test_router_decisions(
vllm_workers, endpoint, MODEL_NAME, request, test_dp_rank=True
......
......@@ -43,8 +43,7 @@ async def main(runtime: DistributedRuntime):
"""Main worker function for template verification."""
# Create service
component = runtime.namespace("test").component("backend")
endpoint = component.endpoint("generate")
endpoint = runtime.endpoint("test.backend.generate")
# Use the existing custom template from fixtures
template_path = Path(SERVE_TEST_DIR) / "fixtures" / "custom_template.jinja"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment