"deploy/operator/config/prometheus/kustomization.yaml" did not exist on "61af664b2cea35bb623a0f402a554c711084d08e"
Unverified Commit 80cac7c1 authored by Tzu-Ling Kan's avatar Tzu-Ling Kan Committed by GitHub
Browse files

feat: Remove Component from public (#6403)


Signed-off-by: default avatartzulingk@nvidia.com <tzulingk@nvidia.com>
parent cb55766c
......@@ -336,7 +336,7 @@ async def init_llm_worker(
endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = endpoint.component()
if shutdown_endpoints is not None:
shutdown_endpoints[:] = [endpoint]
......@@ -419,7 +419,6 @@ async def init_llm_worker(
# publisher will be set later if publishing is enabled.
handler_config = RequestHandlerConfig(
component=component,
engine=engine,
default_sampling_params=default_sampling_params,
publisher=None,
......@@ -456,7 +455,6 @@ async def init_llm_worker(
if config.publish_events_and_metrics:
# Initialize and pass in the publisher to the request handler to
# publish events and metrics.
kv_listener = endpoint.component()
# Use model as fallback if served_model_name is not provided
model_name_for_metrics = config.served_model_name or config.model
metrics_labels = [
......@@ -476,7 +474,7 @@ async def init_llm_worker(
if consolidator_output_endpoint:
# Use the connect endpoint directly (already provided by get_consolidator_endpoints)
consolidator_publisher = KvEventPublisher(
component,
endpoint=endpoint,
kv_block_size=config.kv_block_size,
zmq_endpoint=consolidator_output_connect_endpoint,
zmq_topic="",
......@@ -487,9 +485,8 @@ async def init_llm_worker(
)
async with get_publisher(
component,
endpoint,
engine,
kv_listener,
int(endpoint.connection_id()),
config.kv_block_size,
metrics_labels,
......
......@@ -91,7 +91,7 @@ async def init_video_diffusion_worker(
endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = endpoint.component()
if shutdown_endpoints is not None:
shutdown_endpoints[:] = [endpoint]
......@@ -100,7 +100,7 @@ async def init_video_diffusion_worker(
await engine.initialize()
# Create the request handler
handler = VideoGenerationHandler(component, engine, diffusion_config)
handler = VideoGenerationHandler(engine, diffusion_config)
# Register the model with Dynamo's discovery system
model_name = config.served_model_name or config.model
......
......@@ -268,7 +268,6 @@ class BaseWorkerHandler(ABC):
def __init__(
self,
runtime,
component,
engine,
default_sampling_params,
model_max_len: int | None = None,
......@@ -280,7 +279,6 @@ class BaseWorkerHandler(ABC):
enable_frontend_decoding: bool = False,
):
self.runtime = runtime
self.component = component
self.engine_client = engine
self.default_sampling_params = default_sampling_params
self.kv_publishers: list[KvEventPublisher] | None = None
......@@ -1233,7 +1231,6 @@ class DecodeWorkerHandler(BaseWorkerHandler):
def __init__(
self,
runtime,
component,
engine,
default_sampling_params,
model_max_len: int | None = None,
......@@ -1246,7 +1243,6 @@ class DecodeWorkerHandler(BaseWorkerHandler):
):
super().__init__(
runtime,
component,
engine,
default_sampling_params,
model_max_len,
......@@ -1443,7 +1439,6 @@ class PrefillWorkerHandler(BaseWorkerHandler):
def __init__(
self,
runtime,
component,
engine,
default_sampling_params,
model_max_len: int | None = None,
......@@ -1456,7 +1451,6 @@ class PrefillWorkerHandler(BaseWorkerHandler):
):
super().__init__(
runtime,
component,
engine,
default_sampling_params,
model_max_len,
......
......@@ -11,6 +11,7 @@ from typing import Optional
import uvloop
from prometheus_client import REGISTRY, CollectorRegistry, multiprocess
from vllm.config import VllmConfig
from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.entrypoints.cli.serve import run_headless
from vllm.usage.usage_lib import UsageContext
......@@ -47,7 +48,7 @@ except ImportError:
MediaFetcher = None
MEDIA_DECODER_AVAILABLE = False
from dynamo.runtime import DistributedRuntime
from dynamo.runtime import DistributedRuntime, Endpoint
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.worker_factory import WorkerFactory
......@@ -295,9 +296,8 @@ def setup_metrics_collection(config: Config, generate_endpoint, logger):
def setup_kv_event_publisher(
config: Config,
component,
generate_endpoint,
vllm_config,
generate_endpoint: Endpoint,
vllm_config: VllmConfig,
consolidator_enabled: bool = False,
consolidator_port: Optional[int] = 5558,
) -> Optional[KvEventPublisher]:
......@@ -306,7 +306,6 @@ def setup_kv_event_publisher(
Creates one publisher per dp_rank since each dp_rank publishes to a different port.
Args:
config: Worker configuration
component: Component for runtime integration
generate_endpoint: Endpoint for worker ID
vllm_config: vLLM configuration
consolidator_enabled: If True, subscribe to kv eventconsolidator's ZMQ endpoint
......@@ -355,7 +354,7 @@ def setup_kv_event_publisher(
)
kv_publisher = KvEventPublisher(
component=component,
endpoint=generate_endpoint,
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
zmq_topic="",
......@@ -573,8 +572,9 @@ async def init_prefill(
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks")
clear_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.clear_kv_blocks"
)
# Use pre-created engine if provided (checkpoint mode), otherwise create new
if pre_created_engine is not None:
......@@ -596,7 +596,6 @@ async def init_prefill(
handler = PrefillWorkerHandler(
runtime,
component,
engine_client,
default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
......@@ -627,7 +626,6 @@ async def init_prefill(
# If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
kv_publishers = setup_kv_event_publisher(
config,
component,
generate_endpoint,
vllm_config,
consolidator_enabled=consolidator_enabled,
......@@ -717,8 +715,9 @@ async def init(
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks")
clear_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.clear_kv_blocks"
)
shutdown_endpoints[:] = [
generate_endpoint,
......@@ -727,9 +726,15 @@ async def init(
lora_enabled = config.engine_args.enable_lora
if lora_enabled:
load_lora_endpoint = component.endpoint("load_lora")
unload_lora_endpoint = component.endpoint("unload_lora")
list_loras_endpoint = component.endpoint("list_loras")
load_lora_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.load_lora"
)
unload_lora_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.unload_lora"
)
list_loras_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.list_loras"
)
shutdown_endpoints.extend(
[
......@@ -739,8 +744,6 @@ async def init(
]
)
model_name = config.served_model_name or config.model
# Use pre-created engine if provided (checkpoint mode), otherwise create new
if pre_created_engine is not None:
(
......@@ -752,19 +755,17 @@ async def init(
) = pre_created_engine
# Factory is created after unpack so component_gauges is available
factory = StatLoggerFactory(
component,
endpoint=generate_endpoint,
component_gauges=component_gauges,
dp_rank=config.engine_args.data_parallel_rank or 0,
metrics_labels=[("model", model_name)],
)
else:
# Factory is created without component_gauges; setup_vllm_engine() will
# create the gauges after setup_multiprocess_prometheus() and set them
# on the factory before vLLM calls create_stat_logger().
factory = StatLoggerFactory(
component,
endpoint=generate_endpoint,
dp_rank=config.engine_args.data_parallel_rank or 0,
metrics_labels=[("model", model_name)],
)
(
engine_client,
......@@ -780,7 +781,6 @@ async def init(
handler = DecodeWorkerHandler(
runtime,
component,
engine_client,
default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
......@@ -811,7 +811,6 @@ async def init(
# If kv event consolidator is enabled, publisher will subscribe to kv event consolidator's output
kv_publishers = setup_kv_event_publisher(
config,
component,
generate_endpoint,
vllm_config,
consolidator_enabled=consolidator_enabled,
......@@ -957,7 +956,7 @@ async def init_omni(
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# Initialize media filesystem for storing generated images/videos
......@@ -968,7 +967,6 @@ async def init_omni(
# Initialize unified OmniHandler
handler = OmniHandler(
runtime=runtime,
component=component,
config=config,
default_sampling_params={},
shutdown_event=shutdown_event,
......
......@@ -20,7 +20,7 @@ from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingReceiver,
NixlPersistentEmbeddingReceiver,
)
from dynamo.runtime import Client, Component, DistributedRuntime
from dynamo.runtime import Client, DistributedRuntime
from ..args import Config
from ..handlers import BaseWorkerHandler, build_sampling_params
......@@ -44,7 +44,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
def __init__(
self,
runtime,
component: Component,
engine_client: AsyncLLM,
config: Config,
encode_worker_client: Client | None = None,
......@@ -60,7 +59,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
# Call BaseWorkerHandler.__init__ with proper parameters
super().__init__(
runtime,
component,
engine_client,
default_sampling_params,
enable_multimodal=config.enable_multimodal,
......
......@@ -22,7 +22,6 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
def __init__(
self,
runtime,
component,
engine_client,
config: Config,
shutdown_event=None,
......@@ -36,7 +35,6 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
# Call BaseWorkerHandler.__init__ with proper parameters
super().__init__(
runtime,
component,
engine_client,
default_sampling_params,
enable_multimodal=config.enable_multimodal,
......
......@@ -68,7 +68,6 @@ class OmniHandler(BaseOmniHandler):
def __init__(
self,
runtime,
component,
config,
default_sampling_params: Dict[str, Any],
shutdown_event: asyncio.Event | None = None,
......@@ -88,7 +87,6 @@ class OmniHandler(BaseOmniHandler):
"""
super().__init__(
runtime=runtime,
component=component,
config=config,
default_sampling_params=default_sampling_params,
shutdown_event=shutdown_event,
......
......@@ -3,7 +3,7 @@
import asyncio
import logging
from typing import List, Optional, Tuple
from typing import Optional
from prometheus_client import CollectorRegistry
from vllm.config import VllmConfig
......@@ -12,7 +12,7 @@ from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from dynamo.common.utils.prometheus import LLMBackendMetrics
from dynamo.llm import WorkerMetricsPublisher
from dynamo.runtime import Component
from dynamo.runtime import Endpoint
# Create a dedicated registry for dynamo_component metrics
# This ensures these metrics are isolated and can be exposed via their own callback
......@@ -42,15 +42,14 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
def __init__(
self,
component: Component,
dp_rank: int,
component_gauges: LLMBackendMetrics,
metrics_labels: Optional[List[Tuple[str, str]]] = None,
endpoint: Endpoint,
dp_rank: int = 0,
component_gauges: Optional[LLMBackendMetrics] = None,
) -> None:
self.inner = WorkerMetricsPublisher()
self._component = component
self._endpoint = endpoint
self.dp_rank = dp_rank
self.component_gauges = component_gauges
self.component_gauges = component_gauges or LLMBackendMetrics()
self.num_gpu_block = 1
# Schedule async endpoint creation
self._endpoint_task = asyncio.create_task(self._create_endpoint())
......@@ -58,7 +57,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
async def _create_endpoint(self) -> None:
"""Create the NATS endpoint asynchronously."""
try:
await self.inner.create_endpoint(self._component)
await self.inner.create_endpoint(self._endpoint)
logging.debug("vLLM metrics publisher endpoint created")
except Exception:
logging.exception("Failed to create vLLM metrics publisher endpoint")
......@@ -105,16 +104,14 @@ class StatLoggerFactory:
def __init__(
self,
component: Component,
endpoint: Endpoint,
component_gauges: Optional[LLMBackendMetrics] = None,
dp_rank: int = 0,
metrics_labels: Optional[List[Tuple[str, str]]] = None,
) -> None:
self.component = component
self.endpoint = endpoint
self.component_gauges = component_gauges
self.created_logger: Optional[DynamoStatLoggerPublisher] = None
self.dp_rank = dp_rank
self.metrics_labels = metrics_labels or []
def create_stat_logger(self, dp_rank: int) -> StatLoggerBase:
if self.dp_rank != dp_rank:
......@@ -124,11 +121,11 @@ class StatLoggerFactory:
assert (
self.component_gauges is not None
), "component_gauges must be set before creating stat loggers"
logger = DynamoStatLoggerPublisher(
self.component,
dp_rank,
endpoint=self.endpoint,
dp_rank=dp_rank,
component_gauges=self.component_gauges,
metrics_labels=self.metrics_labels,
)
self.created_logger = logger
......
......@@ -62,7 +62,6 @@ def _make_handler(
with patch.object(mod.BaseWorkerHandler, "__init__", return_value=None):
return mod.MultimodalPDWorkerHandler(
runtime=MagicMock(),
component=MagicMock(),
engine_client=MagicMock(),
config=config,
encode_worker_client=encode_worker_client,
......
......@@ -105,6 +105,7 @@ class TestCreate:
Mock(),
Mock(),
"/tmp/prometheus",
Mock(),
)
await factory.create(
......
......@@ -100,14 +100,22 @@ class WorkerFactory:
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks")
clear_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.clear_kv_blocks"
)
shutdown_endpoints[:] = [generate_endpoint, clear_endpoint]
lora_enabled = config.engine_args.enable_lora
if lora_enabled:
load_lora_endpoint = component.endpoint("load_lora")
unload_lora_endpoint = component.endpoint("unload_lora")
list_loras_endpoint = component.endpoint("list_loras")
load_lora_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.load_lora"
)
unload_lora_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.unload_lora"
)
list_loras_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.list_loras"
)
shutdown_endpoints.extend(
[load_lora_endpoint, unload_lora_endpoint, list_loras_endpoint]
)
......@@ -152,7 +160,6 @@ class WorkerFactory:
if config.multimodal_decode_worker:
handler = MultimodalDecodeWorkerHandler(
runtime,
component,
engine_client,
config,
shutdown_event,
......@@ -161,7 +168,6 @@ class WorkerFactory:
else:
handler = MultimodalPDWorkerHandler(
runtime,
component,
engine_client,
config,
encode_worker_client,
......@@ -175,7 +181,7 @@ class WorkerFactory:
# Set up KV event publisher for prefix caching if enabled
kv_publisher = self.setup_kv_event_publisher(
config, component, generate_endpoint, vllm_config
config, generate_endpoint, vllm_config
)
if kv_publisher:
handler.kv_publisher = kv_publisher
......
......@@ -22,7 +22,7 @@ from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from dynamo.llm import WorkerMetricsPublisher
from dynamo.runtime import Component
from dynamo.runtime import Endpoint
class NullStatLogger(StatLoggerBase):
......@@ -48,11 +48,11 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
def __init__(
self,
component: Component,
endpoint: Endpoint,
dp_rank: int,
) -> None:
self.inner = WorkerMetricsPublisher()
self._component = component
self._endpoint = endpoint
self.dp_rank = dp_rank
self.num_gpu_block = 1
# Schedule async endpoint creation
......@@ -61,7 +61,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
async def _create_endpoint(self) -> None:
"""Create the NATS endpoint asynchronously."""
try:
await self.inner.create_endpoint(self._component)
await self.inner.create_endpoint(self._endpoint)
logging.debug("Multimodal metrics publisher endpoint created")
except Exception:
logging.exception("Failed to create multimodal metrics publisher endpoint")
......@@ -94,11 +94,11 @@ class StatLoggerFactory:
def __init__(
self,
component: Component,
endpoint: Endpoint,
dp_rank: int = 0,
metrics_labels: Optional[List[Tuple[str, str]]] = None,
) -> None:
self.component = component
self.endpoint = endpoint
self.created_logger: Optional[DynamoStatLoggerPublisher] = None
self.dp_rank = dp_rank
self.metrics_labels = metrics_labels or []
......@@ -106,7 +106,7 @@ class StatLoggerFactory:
def create_stat_logger(self, dp_rank: int) -> StatLoggerBase:
if self.dp_rank != dp_rank:
return NullStatLogger()
logger = DynamoStatLoggerPublisher(self.component, dp_rank)
logger = DynamoStatLoggerPublisher(self.endpoint, dp_rank)
self.created_logger = logger
return logger
......
......@@ -24,7 +24,7 @@ from vllm.v1.engine.async_llm import AsyncLLM
import dynamo.nixl_connect as connect
from dynamo.llm import KvEventPublisher
from dynamo.runtime import Component, DistributedRuntime, Endpoint, dynamo_worker
from dynamo.runtime import DistributedRuntime, Endpoint, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
......@@ -104,7 +104,6 @@ class VllmBaseWorker:
def __init__(
self,
args: argparse.Namespace,
component: Component,
endpoint: Endpoint,
config: Config,
):
......@@ -113,15 +112,15 @@ class VllmBaseWorker:
self.downstream_endpoint = args.downstream_endpoint
self.engine_args = config.engine_args
self.config = config
self.setup_vllm_engine(component, endpoint)
self.setup_vllm_engine(endpoint)
async def async_init(self, runtime: DistributedRuntime):
pass
def setup_vllm_engine(self, component: Component, endpoint: Endpoint):
def setup_vllm_engine(self, endpoint: Endpoint):
"""Initialize the vLLM engine.
This method sets up the vLLM engine client, and configures the dynamo-aware KV
event publisher and metrics stats logger based on component and endpoint.
event publisher and metrics stats logger based on endpoint.
"""
os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests
......@@ -138,9 +137,8 @@ class VllmBaseWorker:
# Create vLLM engine with metrics logger and KV event publisher attached
self.stats_logger = StatLoggerFactory(
component,
self.engine_args.data_parallel_rank or 0,
metrics_labels=[("model", self.config.model)],
endpoint=endpoint,
dp_rank=self.engine_args.data_parallel_rank or 0,
)
self.engine_client = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
......@@ -164,7 +162,7 @@ class VllmBaseWorker:
).replace("*", "127.0.0.1")
self.kv_publisher = KvEventPublisher(
component=component,
endpoint=endpoint,
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
)
......@@ -435,15 +433,14 @@ async def init(runtime: DistributedRuntime, args: argparse.Namespace, config: Co
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks")
clear_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.clear_kv_blocks"
)
if args.worker_type in ["prefill", "encode_prefill"]:
handler: VllmBaseWorker = VllmPDWorker(
args, component, generate_endpoint, config
)
handler: VllmBaseWorker = VllmPDWorker(args, generate_endpoint, config)
elif args.worker_type == "decode":
handler = VllmDecodeWorker(args, component, generate_endpoint, config)
handler = VllmDecodeWorker(args, generate_endpoint, config)
await handler.async_init(runtime)
logger.info(f"Starting to serve the {args.endpoint} endpoint...")
......
......@@ -40,8 +40,8 @@ async def worker(runtime: DistributedRuntime):
async def init(runtime: DistributedRuntime, ns: str):
"""
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
Create and serve the `generate` endpoint using the distributed runtime.
Multiple endpoints can be served from a single worker.
"""
endpoint = runtime.endpoint(f"{ns}.backend.generate")
print("Started server instance")
......
......@@ -45,8 +45,8 @@ async def graceful_shutdown(runtime: DistributedRuntime):
async def init(runtime: DistributedRuntime, ns: str):
"""
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
Create and serve the `generate` endpoint using the distributed runtime.
Multiple endpoints can be served from a single worker.
"""
endpoint = runtime.endpoint(f"{ns}.backend.generate")
print("Started server instance")
......
......@@ -38,8 +38,8 @@ class RequestHandler:
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
"""
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
Create and serve the `generate` endpoint using the distributed runtime.
Multiple endpoints can be served from a single worker.
"""
endpoint = runtime.endpoint("dynamo.backend.generate")
await endpoint.serve_endpoint(RequestHandler().generate)
......
......@@ -150,7 +150,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(llm::entrypoint::run_input, m)?)?;
m.add_class::<DistributedRuntime>()?;
m.add_class::<Component>()?;
m.add_class::<Endpoint>()?;
m.add_class::<ModelCardInstanceId>()?;
m.add_class::<Client>()?;
......@@ -460,13 +459,6 @@ struct CancellationToken {
inner: rs::CancellationToken,
}
#[pyclass]
#[derive(Clone)]
struct Component {
inner: rs::component::Component,
event_loop: PyObject,
}
#[pyclass]
#[derive(Clone)]
struct Endpoint {
......@@ -774,17 +766,6 @@ impl DistributedRuntime {
}
}
#[pymethods]
impl Component {
fn endpoint(&self, name: String) -> PyResult<Endpoint> {
let inner = self.inner.endpoint(name);
Ok(Endpoint {
inner,
event_loop: self.event_loop.clone(),
})
}
}
#[pymethods]
impl Endpoint {
#[pyo3(signature = (generator, graceful_shutdown = true, metrics_labels = None, health_check_payload = None))]
......@@ -907,17 +888,6 @@ impl Endpoint {
Ok(())
})
}
/// Get the parent Component.
///
/// Note: To avoid duplicate metrics registries, reuse the returned Component for
/// multiple endpoints: `component.endpoint("ep1")`, `component.endpoint("ep2")`.
fn component(&self) -> Component {
Component {
inner: self.inner.component().clone(),
event_loop: self.event_loop.clone(),
}
}
}
#[pymethods]
......
......@@ -9,7 +9,7 @@ use std::sync::mpsc;
use tokio_stream::StreamExt;
use super::*;
use crate::Component;
use crate::Endpoint;
use llm_rs::kv_router::protocols::compute_block_hash_for_seq;
use rs::pipeline::{AsyncEngine, SingleIn};
use rs::protocols::annotated::Annotated as RsAnnotated;
......@@ -86,14 +86,14 @@ impl WorkerMetricsPublisher {
})
}
#[pyo3(signature = (component))]
#[pyo3(signature = (endpoint))]
fn create_endpoint<'p>(
&self,
py: Python<'p>,
component: Component,
endpoint: Endpoint,
) -> PyResult<Bound<'p, PyAny>> {
let rs_publisher = self.inner.clone();
let rs_component = component.inner.clone();
let rs_component = endpoint.inner.component().clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
rs_publisher
.create_endpoint(rs_component)
......@@ -127,9 +127,9 @@ pub(crate) struct KvEventPublisher {
#[pymethods]
impl KvEventPublisher {
#[new]
#[pyo3(signature = (component, worker_id=0, kv_block_size=0, dp_rank=0, enable_local_indexer=false, zmq_endpoint=None, zmq_topic=None))]
#[pyo3(signature = (endpoint, worker_id=0, kv_block_size=0, dp_rank=0, enable_local_indexer=false, zmq_endpoint=None, zmq_topic=None))]
fn new(
component: Component,
endpoint: Endpoint,
worker_id: WorkerId,
kv_block_size: usize,
dp_rank: DpRank,
......@@ -139,8 +139,8 @@ impl KvEventPublisher {
) -> PyResult<Self> {
let _ = worker_id;
let source_config = zmq_endpoint.map(|endpoint| KvEventSourceConfig::Zmq {
endpoint,
let source_config = zmq_endpoint.map(|ep| KvEventSourceConfig::Zmq {
endpoint: ep,
topic: zmq_topic.unwrap_or_default(),
});
......@@ -148,8 +148,11 @@ impl KvEventPublisher {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
}
// Extract component from endpoint
let component = endpoint.inner.component().clone();
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new_with_local_indexer(
component.inner,
component,
kv_block_size as u32,
source_config,
enable_local_indexer,
......
......@@ -111,19 +111,6 @@ class DistributedRuntime:
"""
...
class Component:
"""
A component is a collection of endpoints
"""
...
def endpoint(self, name: str) -> Endpoint:
"""
Create an endpoint
"""
...
class Endpoint:
"""
......@@ -190,25 +177,6 @@ class Endpoint:
"""
...
def component(self) -> Component:
"""
Get the parent Component that this endpoint belongs to.
Returns:
Component: The parent component
Note:
To avoid duplicate metrics registries, reuse the returned Component for
multiple endpoints: component.endpoint("ep1"), component.endpoint("ep2")
Example:
endpoint = runtime.endpoint("demo.backend.generate")
component = endpoint.component()
health_endpoint = component.endpoint("health") # Reuse component
"""
...
class Client:
"""
A client capable of calling served instances of an endpoint
......@@ -404,14 +372,15 @@ class WorkerMetricsPublisher:
Create a `WorkerMetricsPublisher` object
"""
async def create_endpoint(self, component: Component) -> None:
async def create_endpoint(self, endpoint: Endpoint) -> None:
"""
Create the NATS endpoint for metrics publishing. Must be awaited.
Initialize the NATS endpoint for publishing worker metrics. Must be awaited.
Only service created through this method will interact with KV router of the same component.
Extracts component information from the endpoint to set up metrics publishing
on the correct NATS subject for routing decisions.
Args:
component: The component to create the endpoint for
endpoint: The endpoint to extract component information from for metrics publishing
"""
def publish(
......@@ -575,7 +544,7 @@ class KvIndexer:
...
def __init__(self, component: Component, block_size: int) -> None:
def __init__(self, endpoint: Endpoint, block_size: int) -> None:
"""
Create a `KvIndexer` object
"""
......@@ -622,7 +591,7 @@ class ApproxKvIndexer:
def __init__(
self,
component: Component,
endpoint: Endpoint,
kv_block_size: int,
router_ttl_secs: float = 120.0,
router_max_tree_size: int = 1048576,
......@@ -689,7 +658,7 @@ class KvEventPublisher:
def __init__(
self,
component: Component,
endpoint: Endpoint,
worker_id: int = 0,
kv_block_size: int = 0,
dp_rank: int = 0,
......@@ -706,8 +675,8 @@ class KvEventPublisher:
When zmq_endpoint is None, events are pushed manually via publish_stored/publish_removed.
Args:
component: The component to publish events for
worker_id: The worker ID (unused, inferred from component)
endpoint: The endpoint to extract component information from for event publishing
worker_id: The worker ID (unused, inferred from endpoint)
kv_block_size: The KV block size (must be > 0)
dp_rank: The data parallel rank (defaults to 0)
enable_local_indexer: Enable worker-local KV indexer
......@@ -1612,7 +1581,6 @@ class VirtualConnectorClient:
__all__ = [
"Client",
"Component",
"Context",
"KserveGrpcService",
"ModelDeploymentCard",
......
......@@ -11,11 +11,9 @@ from pydantic import BaseModel, ValidationError
# List all the classes in the _core module for re-export
# import * causes "unable to detect undefined names"
from dynamo._core import Client as Client
from dynamo._core import Component as Component
from dynamo._core import Context as Context
from dynamo._core import DistributedRuntime as DistributedRuntime
from dynamo._core import Endpoint as Endpoint
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
def dynamo_worker(enable_nats: bool = True):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment