Unverified Commit 80cac7c1 authored by Tzu-Ling Kan's avatar Tzu-Ling Kan Committed by GitHub
Browse files

feat: Remove Component from public (#6403)


Signed-off-by: default avatartzulingk@nvidia.com <tzulingk@nvidia.com>
parent cb55766c
......@@ -71,9 +71,6 @@ async def main(runtime: DistributedRuntime, args):
logger.info("=" * 60)
# Create the GlobalPlanner component (get from first endpoint)
component = runtime.endpoint(f"{namespace}.GlobalPlanner.scale_request").component()
# Get K8s namespace (where GlobalPlanner pod is running)
k8s_namespace = os.environ.get("POD_NAMESPACE", "default")
logger.info(f"Running in Kubernetes namespace: {k8s_namespace}")
......@@ -87,7 +84,7 @@ async def main(runtime: DistributedRuntime, args):
# Serve scale_request endpoint
logger.info("Serving endpoints...")
scale_endpoint = component.endpoint("scale_request")
scale_endpoint = runtime.endpoint(f"{namespace}.GlobalPlanner.scale_request")
await scale_endpoint.serve_endpoint(handler.scale_request)
logger.info(" ✓ scale_request - Receives scaling requests from Planners")
......@@ -101,7 +98,7 @@ async def main(runtime: DistributedRuntime, args):
"managed_namespaces": args.managed_namespaces or "all",
}
health_endpoint = component.endpoint("health")
health_endpoint = runtime.endpoint(f"{namespace}.GlobalPlanner.health")
await health_endpoint.serve_endpoint(health_check)
logger.info(" ✓ health - Health check endpoint")
......
......@@ -74,8 +74,9 @@ async def worker(runtime: DistributedRuntime):
prefill_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component_name}.prefill_generate"
)
component = prefill_endpoint.component()
decode_endpoint = component.endpoint("decode_generate")
decode_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component_name}.decode_generate"
)
logger.info("Registering as prefill worker...")
# Register as prefill worker - frontend will send prefill requests here
......
......@@ -185,10 +185,9 @@ async def worker(runtime: DistributedRuntime):
)
await handler.initialize()
# Create endpoints (get component from first endpoint to avoid duplicate metrics registries)
# Create endpoints
generate_endpoint = runtime.endpoint(f"{config.namespace}.router.generate")
component = generate_endpoint.component()
best_worker_endpoint = component.endpoint("best_worker_id")
best_worker_endpoint = runtime.endpoint(f"{config.namespace}.router.best_worker_id")
logger.debug("Starting to serve endpoints...")
......
......@@ -304,13 +304,13 @@ async def init(
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint
engine, config, generate_endpoint
)
# Record model load time immediately after publisher setup (which creates the gauges)
......@@ -327,7 +327,7 @@ async def init(
ready_event = asyncio.Event()
handler = DecodeWorkerHandler(
component, engine, config, publisher, generate_endpoint, shutdown_event
engine, config, publisher, generate_endpoint, shutdown_event
)
handler.register_engine_routes(runtime)
......@@ -395,13 +395,13 @@ async def init_prefill(
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint
engine, config, generate_endpoint
)
# Handle non-leader nodes (multi-node parallelism)
......@@ -415,7 +415,7 @@ async def init_prefill(
await _warmup_prefill_engine(engine, server_args)
handler = PrefillWorkerHandler(
component, engine, config, publisher, generate_endpoint, shutdown_event
engine, config, publisher, generate_endpoint, shutdown_event
)
handler.register_engine_routes(runtime)
......@@ -487,13 +487,13 @@ async def init_diffusion(
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint
engine, config, generate_endpoint
)
# Handle non-leader nodes (multi-node parallelism)
......@@ -506,7 +506,7 @@ async def init_diffusion(
ready_event = asyncio.Event()
handler = DiffusionWorkerHandler(
component, engine, config, publisher, generate_endpoint, shutdown_event
engine, config, publisher, generate_endpoint, shutdown_event
)
handler.register_engine_routes(runtime)
......@@ -567,20 +567,18 @@ async def init_embedding(
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint
engine, config, generate_endpoint
)
# Readiness gate: requests wait until model is registered
ready_event = asyncio.Event()
handler = EmbeddingWorkerHandler(
component, engine, config, publisher, shutdown_event
)
handler = EmbeddingWorkerHandler(engine, config, publisher, shutdown_event)
health_check_payload = SglangHealthCheckPayload(
engine, use_text_input=dynamo_args.use_sglang_tokenizer
).to_dict()
......@@ -659,14 +657,13 @@ async def init_image_diffusion(
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# Image diffusion doesn't have metrics publisher like LLM
# Could add custom metrics for images/sec, steps/sec later
handler = ImageDiffusionWorkerHandler(
component,
generator,
config,
publisher=None,
......@@ -744,11 +741,10 @@ async def init_video_generation(
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
handler = VideoGenerationWorkerHandler(
component,
generator,
config,
publisher=None,
......@@ -799,7 +795,7 @@ async def init_multimodal_processor(
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# For processor, we need to connect to the encode worker
......@@ -809,9 +805,7 @@ async def init_multimodal_processor(
ready_event = asyncio.Event()
handler = MultimodalProcessorHandler(
component, config, encode_worker_client, shutdown_event
)
handler = MultimodalProcessorHandler(config, encode_worker_client, shutdown_event)
logging.info("Waiting for Encoder Worker Instances ...")
await encode_worker_client.wait_for_instances()
......@@ -858,7 +852,7 @@ async def init_multimodal_encode_worker(
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# For encode worker, we need to connect to the downstream LLM worker
......@@ -866,9 +860,7 @@ async def init_multimodal_encode_worker(
f"{dynamo_args.namespace}.backend.generate"
).client()
handler = MultimodalEncodeWorkerHandler(
component, config, pd_worker_client, shutdown_event
)
handler = MultimodalEncodeWorkerHandler(config, pd_worker_client, shutdown_event)
await handler.async_init(runtime)
await pd_worker_client.wait_for_instances()
......@@ -912,7 +904,7 @@ async def init_multimodal_worker(
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
engine = sgl.Engine(server_args=server_args)
......@@ -923,12 +915,10 @@ async def init_multimodal_worker(
f"{dynamo_args.namespace}.prefill.generate"
).client()
handler = MultimodalWorkerHandler(
component, engine, config, prefill_client, shutdown_event
engine, config, prefill_client, shutdown_event
)
else:
handler = MultimodalWorkerHandler(
component, engine, config, None, shutdown_event
)
handler = MultimodalWorkerHandler(engine, config, None, shutdown_event)
await handler.async_init()
......@@ -968,10 +958,11 @@ async def init_multimodal_prefill_worker(
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
handler = MultimodalPrefillWorkerHandler(engine, config, shutdown_event)
shutdown_endpoints[:] = [generate_endpoint]
handler = MultimodalPrefillWorkerHandler(component, engine, config, shutdown_event)
await handler.async_init()
health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()
......
......@@ -21,7 +21,7 @@ from dynamo.common.utils.prometheus import (
register_engine_metrics_callback,
)
from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher
from dynamo.runtime import Component, Endpoint
from dynamo.runtime import Endpoint
from dynamo.sglang.args import Config
......@@ -63,7 +63,6 @@ class DynamoSglangPublisher:
self,
engine: sgl.Engine,
config: Config,
component: Component,
generate_endpoint: Endpoint,
component_gauges: LLMBackendMetrics,
metrics_labels: Optional[List[Tuple[str, str]]] = None,
......@@ -73,7 +72,6 @@ class DynamoSglangPublisher:
Args:
engine: The SGLang engine instance.
config: SGLang configuration including server args.
component: The Dynamo runtime component.
generate_endpoint: The Dynamo endpoint for generation requests.
metrics_labels: Optional list of label key-value pairs for metrics.
component_gauges: LLM backend metrics instance (created via LLMBackendMetrics()).
......@@ -82,7 +80,6 @@ class DynamoSglangPublisher:
self.server_args = config.server_args
self.dynamo_args = config.dynamo_args
self.generate_endpoint = generate_endpoint
self.component = component
self.metrics_publisher = WorkerMetricsPublisher()
self.component_gauges = component_gauges
# Endpoint creation is deferred to async context in setup_sgl_metrics
......@@ -257,7 +254,7 @@ class DynamoSglangPublisher:
f"(connecting to {zmq_ep})"
)
publisher = KvEventPublisher(
component=self.component,
endpoint=self.generate_endpoint,
kv_block_size=self.server_args.page_size,
zmq_endpoint=zmq_ep,
zmq_topic="",
......@@ -322,7 +319,6 @@ def setup_prometheus_registry(
async def setup_sgl_metrics(
engine: sgl.Engine,
config: Config,
component: Component,
generate_endpoint: Endpoint,
) -> tuple[DynamoSglangPublisher, asyncio.Task, list[tuple[str, str]]]:
"""Create publisher, initialize metrics, and start the metrics publishing loop.
......@@ -330,7 +326,6 @@ async def setup_sgl_metrics(
Args:
engine: The SGLang engine instance.
config: SGLang configuration including server args.
component: The Dynamo runtime component.
generate_endpoint: The Dynamo endpoint for generation requests.
Returns:
......@@ -366,13 +361,12 @@ async def setup_sgl_metrics(
publisher = DynamoSglangPublisher(
engine,
config,
component,
generate_endpoint,
component_gauges=component_gauges,
metrics_labels=metrics_labels,
)
# Create endpoint in async context (must await before publishing)
await publisher.metrics_publisher.create_endpoint(component)
await publisher.metrics_publisher.create_endpoint(generate_endpoint)
logging.debug("SGLang metrics publisher endpoint created")
publisher.init_engine_metrics_publish()
......
......@@ -7,7 +7,7 @@ from typing import Optional
import sglang as sgl
from dynamo._core import Component, Context
from dynamo._core import Context
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import EmbeddingRequest
from dynamo.sglang.publisher import DynamoSglangPublisher
......@@ -17,13 +17,12 @@ from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
class EmbeddingWorkerHandler(BaseWorkerHandler):
def __init__(
self,
component: Component,
engine: sgl.Engine,
config: Config,
publisher: Optional[DynamoSglangPublisher] = None,
shutdown_event: Optional[asyncio.Event] = None,
):
super().__init__(component, engine, config, publisher, None, shutdown_event)
super().__init__(engine, config, publisher, None, shutdown_event)
logging.info("Embedding worker handler initialized")
def cleanup(self):
......
......@@ -13,7 +13,7 @@ from typing import Any, AsyncGenerator, Dict, Optional, Tuple
import sglang as sgl
from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Component, Context
from dynamo._core import Context
from dynamo.common.utils.input_params import InputParamManager
from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
......@@ -30,18 +30,15 @@ class BaseGenerativeHandler(ABC):
def __init__(
self,
component: Component,
config: Config,
publisher: Optional[DynamoSglangPublisher] = None,
) -> None:
"""Initialize base generative handler.
Args:
component: The Dynamo runtime component.
config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher for the worker.
"""
self.component = component
self.config = config
# Set up metrics and KV publishers
......@@ -98,7 +95,6 @@ class BaseWorkerHandler(BaseGenerativeHandler):
def __init__(
self,
component: Component,
engine: sgl.Engine,
config: Config,
publisher: Optional[DynamoSglangPublisher] = None,
......@@ -108,7 +104,6 @@ class BaseWorkerHandler(BaseGenerativeHandler):
"""Initialize base worker handler.
Args:
component: The Dynamo runtime component.
engine: The SGLang engine instance.
config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher for the worker.
......@@ -116,7 +111,7 @@ class BaseWorkerHandler(BaseGenerativeHandler):
shutdown_event: Optional event to signal shutdown.
"""
# Call parent constructor
super().__init__(component, config, publisher)
super().__init__(config, publisher)
# LLM-specific initialization
self.engine = engine
......
......@@ -13,7 +13,7 @@ from typing import Any, AsyncGenerator, Optional
import torch
from PIL import Image
from dynamo._core import Component, Context
from dynamo._core import Context
from dynamo.common.storage import upload_to_fs
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import CreateImageRequest, ImageData, ImagesResponse, NvExt
......@@ -34,7 +34,6 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
def __init__(
self,
component: Component,
generator: Any, # DiffGenerator, not sgl.Engine
config: Config,
publisher: Optional[DynamoSglangPublisher] = None,
......@@ -43,13 +42,12 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
"""Initialize diffusion worker handler.
Args:
component: The Dynamo runtime component.
generator: The SGLang DiffGenerator instance.
config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher (not used for diffusion currently).
fs: Optional fsspec filesystem for primary image storage.
"""
super().__init__(component, config, publisher)
super().__init__(config, publisher)
self.generator = generator # DiffGenerator, not Engine
self.fs = fs
......
......@@ -8,7 +8,7 @@ from typing import Any, AsyncGenerator, Dict, Optional
import sglang as sgl
from dynamo._core import Component, Context
from dynamo._core import Context
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.sglang.args import Config, DisaggregationMode
from dynamo.sglang.publisher import DynamoSglangPublisher
......@@ -20,7 +20,6 @@ class DecodeWorkerHandler(BaseWorkerHandler):
def __init__(
self,
component: Component,
engine: sgl.Engine,
config: Config,
publisher: DynamoSglangPublisher,
......@@ -30,7 +29,6 @@ class DecodeWorkerHandler(BaseWorkerHandler):
"""Initialize decode worker handler.
Args:
component: The Dynamo runtime component.
engine: The SGLang engine instance.
config: SGLang and Dynamo configuration.
publisher: Metrics publisher for the worker.
......@@ -38,7 +36,6 @@ class DecodeWorkerHandler(BaseWorkerHandler):
generate_endpoint: The endpoint handle for discovery registration.
"""
super().__init__(
component,
engine,
config,
publisher,
......
......@@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, Dict, Optional
import sglang as sgl
from dynamo._core import Component, Context
from dynamo._core import Context
from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.llm.decode_handler import DecodeWorkerHandler
......@@ -19,7 +19,6 @@ class DiffusionWorkerHandler(DecodeWorkerHandler):
def __init__(
self,
component: Component,
engine: sgl.Engine,
config: Config,
publisher: DynamoSglangPublisher = None,
......@@ -29,16 +28,13 @@ class DiffusionWorkerHandler(DecodeWorkerHandler):
"""Initialize diffusion worker handler.
Args:
component: The Dynamo runtime component.
engine: SGLang engine with diffusion algorithm configured.
config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher.
generate_endpoint: The endpoint handle for discovery.
shutdown_event: Optional event to signal shutdown.
"""
super().__init__(
component, engine, config, publisher, generate_endpoint, shutdown_event
)
super().__init__(engine, config, publisher, generate_endpoint, shutdown_event)
# Validate that diffusion algorithm is configured
if (
......
......@@ -7,7 +7,7 @@ from typing import Any, AsyncGenerator, Dict, Optional
import sglang as sgl
from dynamo._core import Component, Context
from dynamo._core import Context
from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
......@@ -18,7 +18,6 @@ class PrefillWorkerHandler(BaseWorkerHandler):
def __init__(
self,
component: Component,
engine: sgl.Engine,
config: Config,
publisher: DynamoSglangPublisher,
......@@ -28,7 +27,6 @@ class PrefillWorkerHandler(BaseWorkerHandler):
"""Initialize prefill worker handler.
Args:
component: The Dynamo runtime component.
engine: The SGLang engine instance.
config: SGLang and Dynamo configuration.
publisher: The SGLang publisher instance.
......@@ -37,9 +35,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
"""
self.engine = engine
self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info(self.engine)
super().__init__(
component, engine, config, publisher, generate_endpoint, shutdown_event
)
super().__init__(engine, config, publisher, generate_endpoint, shutdown_event)
self._consume_tasks = set()
logging.info(
f"Prefill worker handler initialized - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}"
......
......@@ -16,7 +16,7 @@ from sglang.srt.parser.conversation import chat_templates
from transformers import AutoTokenizer
import dynamo.nixl_connect as connect
from dynamo._core import Client, Component, Context
from dynamo._core import Client, Context
from dynamo.runtime import DistributedRuntime
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import SglangMultimodalRequest
......@@ -46,14 +46,11 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
def __init__(
self,
component: Component,
config: Config,
pd_worker_client: Client,
shutdown_event: Optional[asyncio.Event] = None,
) -> None:
super().__init__(
component, engine=None, config=config, shutdown_event=shutdown_event
)
super().__init__(engine=None, config=config, shutdown_event=shutdown_event)
self.pd_worker_client = pd_worker_client
self.model = config.server_args.model_path
......
......@@ -10,7 +10,7 @@ from typing import Any, Dict, Optional
from transformers import AutoTokenizer
from dynamo._core import Client, Component, Context
from dynamo._core import Client, Context
from dynamo.sglang.args import Config
from dynamo.sglang.multimodal_utils import (
multimodal_request_to_sglang,
......@@ -34,14 +34,11 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
def __init__(
self,
component: Component,
config: Config,
encode_worker_client: Client,
shutdown_event: Optional[asyncio.Event] = None,
):
super().__init__(
component, engine=None, config=config, shutdown_event=shutdown_event
)
super().__init__(engine=None, config=config, shutdown_event=shutdown_event)
self.encode_worker_client = encode_worker_client
self.chat_template = getattr(config.server_args, "chat_template", "qwen2-vl")
self.model = config.server_args.model_path
......
......@@ -10,7 +10,7 @@ import sglang as sgl
import torch
import dynamo.nixl_connect as connect
from dynamo._core import Client, Component, Context
from dynamo._core import Client, Context
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.sglang.args import Config, DisaggregationMode
from dynamo.sglang.protocol import (
......@@ -253,13 +253,12 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
def __init__(
self,
component: Component,
engine: sgl.Engine,
config: Config,
prefill_client: Client = None,
shutdown_event: Optional[asyncio.Event] = None,
):
super().__init__(component, engine, config, None, None, shutdown_event)
super().__init__(engine, config, None, None, shutdown_event)
# Initialize processors
self.embeddings_processor = EmbeddingsProcessor()
......@@ -435,12 +434,11 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
def __init__(
self,
component: Component,
engine: sgl.Engine,
config: Config,
shutdown_event: Optional[asyncio.Event] = None,
):
super().__init__(component, engine, config, None, None, shutdown_event)
super().__init__(engine, config, None, None, shutdown_event)
# Initialize processors
self.embeddings_processor = EmbeddingsProcessor()
......
......@@ -11,7 +11,7 @@ from typing import Any, AsyncGenerator, Optional
import torch
from dynamo._core import Component, Context
from dynamo._core import Context
from dynamo.common.storage import upload_to_fs
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import (
......@@ -35,7 +35,6 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
def __init__(
self,
component: Component,
generator: Any, # DiffGenerator, not sgl.Engine
config: Config,
publisher: Optional[DynamoSglangPublisher] = None,
......@@ -44,14 +43,13 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
"""Initialize video generation worker handler.
Args:
component: The Dynamo runtime component.
generator: The SGLang DiffGenerator instance.
config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher (not used for video currently).
fs: Optional fsspec filesystem for primary video storage.
"""
# Call parent constructor for common setup
super().__init__(component, config, publisher)
super().__init__(config, publisher)
# Video generation-specific initialization
self.generator = generator # DiffGenerator, not Engine
......
......@@ -23,12 +23,6 @@ pytestmark = [
]
@pytest.fixture
def mock_component():
"""Mock Dynamo Component."""
return MagicMock()
@pytest.fixture
def mock_generator():
"""Mock SGLang DiffGenerator."""
......@@ -67,12 +61,9 @@ def mock_context():
@pytest.fixture
def handler(
mock_component, mock_generator, mock_config, mock_fs
) -> ImageDiffusionWorkerHandler:
def handler(mock_generator, mock_config, mock_fs) -> ImageDiffusionWorkerHandler:
"""Create ImageDiffusionWorkerHandler instance."""
return ImageDiffusionWorkerHandler(
component=mock_component,
generator=mock_generator,
config=mock_config,
publisher=None,
......@@ -90,9 +81,7 @@ class TestImageDiffusionWorkerHandler:
assert handler.fs_url == "file:///tmp/images"
assert handler.base_url == "file:///tmp/images"
def test_initialization_with_url_base(
self, mock_component, mock_generator, mock_fs
):
def test_initialization_with_url_base(self, mock_generator, mock_fs):
"""Test handler initialization with URL base."""
config = MagicMock()
config.dynamo_args = MagicMock()
......@@ -100,7 +89,6 @@ class TestImageDiffusionWorkerHandler:
config.dynamo_args.media_output_http_url = "http://localhost:8008/images"
handler = ImageDiffusionWorkerHandler(
component=mock_component,
generator=mock_generator,
config=config,
publisher=None,
......
......@@ -296,9 +296,8 @@ class Publisher:
def __init__(
self,
component,
endpoint,
engine,
kv_listener,
worker_id,
kv_block_size,
metrics_labels,
......@@ -306,9 +305,8 @@ class Publisher:
zmq_endpoint: Optional[str] = None,
enable_local_indexer: bool = False,
):
self.component = component
self.endpoint = endpoint
self.engine = engine
self.kv_listener = kv_listener
self.worker_id = worker_id
self.kv_block_size = kv_block_size
self.max_window_size = None
......@@ -356,7 +354,7 @@ class Publisher:
if self.metrics_publisher is None:
logging.error("KV metrics publisher not initialized!")
return
await self.metrics_publisher.create_endpoint(self.component)
await self.metrics_publisher.create_endpoint(self.endpoint)
def initialize(self):
# Setup the metrics publisher
......@@ -385,9 +383,9 @@ class Publisher:
self.kv_event_publishers = {}
for rank in range(self.attention_dp_size):
self.kv_event_publishers[rank] = KvEventPublisher(
self.kv_listener,
self.worker_id,
self.kv_block_size,
endpoint=self.endpoint,
worker_id=self.worker_id,
kv_block_size=self.kv_block_size,
dp_rank=rank,
enable_local_indexer=self.enable_local_indexer,
)
......@@ -760,9 +758,8 @@ class Publisher:
@asynccontextmanager
async def get_publisher(
component,
endpoint,
engine,
kv_listener,
worker_id,
kv_block_size,
metrics_labels,
......@@ -771,9 +768,8 @@ async def get_publisher(
enable_local_indexer: bool = False,
):
publisher = Publisher(
component,
endpoint,
engine,
kv_listener,
worker_id,
kv_block_size,
metrics_labels,
......
......@@ -55,7 +55,6 @@ class RequestHandlerConfig:
Configuration for the request handler
"""
component: object
engine: TensorRTLLMEngine
default_sampling_params: SamplingParams
publisher: Publisher
......@@ -89,7 +88,6 @@ class HandlerBase(BaseGenerativeHandler):
def __init__(self, config: RequestHandlerConfig):
self.engine = config.engine
self.component = config.component
self.default_sampling_params = config.default_sampling_params
self.publisher = config.publisher
self.metrics_collector = config.metrics_collector
......
......@@ -13,7 +13,7 @@ import time
import uuid
from typing import Any, AsyncGenerator, Optional
from dynamo._core import Component, Context
from dynamo._core import Context
from dynamo.common.protocols.video_protocol import (
NvCreateVideoRequest,
NvVideosResponse,
......@@ -42,18 +42,15 @@ class VideoGenerationHandler(BaseGenerativeHandler):
def __init__(
self,
component: Component,
engine: DiffusionEngine,
config: DiffusionConfig,
):
"""Initialize the handler.
Args:
component: The Dynamo runtime component.
engine: The DiffusionEngine instance.
config: Diffusion generation configuration.
"""
self.component = component
self.engine = engine
self.config = config
if not config.media_output_fs_url:
......
......@@ -597,7 +597,6 @@ class TestVideoHandlerConcurrency:
return_value=MagicMock(),
):
handler = VideoGenerationHandler(
component=MagicMock(),
engine=mock_engine,
config=config,
)
......@@ -684,7 +683,6 @@ class TestVideoHandlerResponseFormats:
return_value=MagicMock(),
):
handler = VideoGenerationHandler(
component=MagicMock(),
engine=mock_engine,
config=config,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment