Unverified Commit 80cac7c1 authored by Tzu-Ling Kan's avatar Tzu-Ling Kan Committed by GitHub
Browse files

feat: Remove Component from public (#6403)


Signed-off-by: default avatartzulingk@nvidia.com <tzulingk@nvidia.com>
parent cb55766c
...@@ -71,9 +71,6 @@ async def main(runtime: DistributedRuntime, args): ...@@ -71,9 +71,6 @@ async def main(runtime: DistributedRuntime, args):
logger.info("=" * 60) logger.info("=" * 60)
# Create the GlobalPlanner component (get from first endpoint)
component = runtime.endpoint(f"{namespace}.GlobalPlanner.scale_request").component()
# Get K8s namespace (where GlobalPlanner pod is running) # Get K8s namespace (where GlobalPlanner pod is running)
k8s_namespace = os.environ.get("POD_NAMESPACE", "default") k8s_namespace = os.environ.get("POD_NAMESPACE", "default")
logger.info(f"Running in Kubernetes namespace: {k8s_namespace}") logger.info(f"Running in Kubernetes namespace: {k8s_namespace}")
...@@ -87,7 +84,7 @@ async def main(runtime: DistributedRuntime, args): ...@@ -87,7 +84,7 @@ async def main(runtime: DistributedRuntime, args):
# Serve scale_request endpoint # Serve scale_request endpoint
logger.info("Serving endpoints...") logger.info("Serving endpoints...")
scale_endpoint = component.endpoint("scale_request") scale_endpoint = runtime.endpoint(f"{namespace}.GlobalPlanner.scale_request")
await scale_endpoint.serve_endpoint(handler.scale_request) await scale_endpoint.serve_endpoint(handler.scale_request)
logger.info(" ✓ scale_request - Receives scaling requests from Planners") logger.info(" ✓ scale_request - Receives scaling requests from Planners")
...@@ -101,7 +98,7 @@ async def main(runtime: DistributedRuntime, args): ...@@ -101,7 +98,7 @@ async def main(runtime: DistributedRuntime, args):
"managed_namespaces": args.managed_namespaces or "all", "managed_namespaces": args.managed_namespaces or "all",
} }
health_endpoint = component.endpoint("health") health_endpoint = runtime.endpoint(f"{namespace}.GlobalPlanner.health")
await health_endpoint.serve_endpoint(health_check) await health_endpoint.serve_endpoint(health_check)
logger.info(" ✓ health - Health check endpoint") logger.info(" ✓ health - Health check endpoint")
......
...@@ -74,8 +74,9 @@ async def worker(runtime: DistributedRuntime): ...@@ -74,8 +74,9 @@ async def worker(runtime: DistributedRuntime):
prefill_endpoint = runtime.endpoint( prefill_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component_name}.prefill_generate" f"{config.namespace}.{config.component_name}.prefill_generate"
) )
component = prefill_endpoint.component() decode_endpoint = runtime.endpoint(
decode_endpoint = component.endpoint("decode_generate") f"{config.namespace}.{config.component_name}.decode_generate"
)
logger.info("Registering as prefill worker...") logger.info("Registering as prefill worker...")
# Register as prefill worker - frontend will send prefill requests here # Register as prefill worker - frontend will send prefill requests here
......
...@@ -185,10 +185,9 @@ async def worker(runtime: DistributedRuntime): ...@@ -185,10 +185,9 @@ async def worker(runtime: DistributedRuntime):
) )
await handler.initialize() await handler.initialize()
# Create endpoints (get component from first endpoint to avoid duplicate metrics registries) # Create endpoints
generate_endpoint = runtime.endpoint(f"{config.namespace}.router.generate") generate_endpoint = runtime.endpoint(f"{config.namespace}.router.generate")
component = generate_endpoint.component() best_worker_endpoint = runtime.endpoint(f"{config.namespace}.router.best_worker_id")
best_worker_endpoint = component.endpoint("best_worker_id")
logger.debug("Starting to serve endpoints...") logger.debug("Starting to serve endpoints...")
......
...@@ -304,13 +304,13 @@ async def init( ...@@ -304,13 +304,13 @@ async def init(
generate_endpoint = runtime.endpoint( generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint] shutdown_endpoints[:] = [generate_endpoint]
# Setup metrics and KV events for ALL nodes (including non-leader) # Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks # Non-leader nodes need KV event publishing for their local DP ranks
publisher, metrics_task, metrics_labels = await setup_sgl_metrics( publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint engine, config, generate_endpoint
) )
# Record model load time immediately after publisher setup (which creates the gauges) # Record model load time immediately after publisher setup (which creates the gauges)
...@@ -327,7 +327,7 @@ async def init( ...@@ -327,7 +327,7 @@ async def init(
ready_event = asyncio.Event() ready_event = asyncio.Event()
handler = DecodeWorkerHandler( handler = DecodeWorkerHandler(
component, engine, config, publisher, generate_endpoint, shutdown_event engine, config, publisher, generate_endpoint, shutdown_event
) )
handler.register_engine_routes(runtime) handler.register_engine_routes(runtime)
...@@ -395,13 +395,13 @@ async def init_prefill( ...@@ -395,13 +395,13 @@ async def init_prefill(
generate_endpoint = runtime.endpoint( generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint] shutdown_endpoints[:] = [generate_endpoint]
# Setup metrics and KV events for ALL nodes (including non-leader) # Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks # Non-leader nodes need KV event publishing for their local DP ranks
publisher, metrics_task, metrics_labels = await setup_sgl_metrics( publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint engine, config, generate_endpoint
) )
# Handle non-leader nodes (multi-node parallelism) # Handle non-leader nodes (multi-node parallelism)
...@@ -415,7 +415,7 @@ async def init_prefill( ...@@ -415,7 +415,7 @@ async def init_prefill(
await _warmup_prefill_engine(engine, server_args) await _warmup_prefill_engine(engine, server_args)
handler = PrefillWorkerHandler( handler = PrefillWorkerHandler(
component, engine, config, publisher, generate_endpoint, shutdown_event engine, config, publisher, generate_endpoint, shutdown_event
) )
handler.register_engine_routes(runtime) handler.register_engine_routes(runtime)
...@@ -487,13 +487,13 @@ async def init_diffusion( ...@@ -487,13 +487,13 @@ async def init_diffusion(
generate_endpoint = runtime.endpoint( generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint] shutdown_endpoints[:] = [generate_endpoint]
# Setup metrics and KV events for ALL nodes (including non-leader) # Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks # Non-leader nodes need KV event publishing for their local DP ranks
publisher, metrics_task, metrics_labels = await setup_sgl_metrics( publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint engine, config, generate_endpoint
) )
# Handle non-leader nodes (multi-node parallelism) # Handle non-leader nodes (multi-node parallelism)
...@@ -506,7 +506,7 @@ async def init_diffusion( ...@@ -506,7 +506,7 @@ async def init_diffusion(
ready_event = asyncio.Event() ready_event = asyncio.Event()
handler = DiffusionWorkerHandler( handler = DiffusionWorkerHandler(
component, engine, config, publisher, generate_endpoint, shutdown_event engine, config, publisher, generate_endpoint, shutdown_event
) )
handler.register_engine_routes(runtime) handler.register_engine_routes(runtime)
...@@ -567,20 +567,18 @@ async def init_embedding( ...@@ -567,20 +567,18 @@ async def init_embedding(
generate_endpoint = runtime.endpoint( generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint] shutdown_endpoints[:] = [generate_endpoint]
# publisher instantiates the metrics and kv event publishers # publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics( publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint engine, config, generate_endpoint
) )
# Readiness gate: requests wait until model is registered # Readiness gate: requests wait until model is registered
ready_event = asyncio.Event() ready_event = asyncio.Event()
handler = EmbeddingWorkerHandler( handler = EmbeddingWorkerHandler(engine, config, publisher, shutdown_event)
component, engine, config, publisher, shutdown_event
)
health_check_payload = SglangHealthCheckPayload( health_check_payload = SglangHealthCheckPayload(
engine, use_text_input=dynamo_args.use_sglang_tokenizer engine, use_text_input=dynamo_args.use_sglang_tokenizer
).to_dict() ).to_dict()
...@@ -659,14 +657,13 @@ async def init_image_diffusion( ...@@ -659,14 +657,13 @@ async def init_image_diffusion(
generate_endpoint = runtime.endpoint( generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint] shutdown_endpoints[:] = [generate_endpoint]
# Image diffusion doesn't have metrics publisher like LLM # Image diffusion doesn't have metrics publisher like LLM
# Could add custom metrics for images/sec, steps/sec later # Could add custom metrics for images/sec, steps/sec later
handler = ImageDiffusionWorkerHandler( handler = ImageDiffusionWorkerHandler(
component,
generator, generator,
config, config,
publisher=None, publisher=None,
...@@ -744,11 +741,10 @@ async def init_video_generation( ...@@ -744,11 +741,10 @@ async def init_video_generation(
generate_endpoint = runtime.endpoint( generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint] shutdown_endpoints[:] = [generate_endpoint]
handler = VideoGenerationWorkerHandler( handler = VideoGenerationWorkerHandler(
component,
generator, generator,
config, config,
publisher=None, publisher=None,
...@@ -799,7 +795,7 @@ async def init_multimodal_processor( ...@@ -799,7 +795,7 @@ async def init_multimodal_processor(
generate_endpoint = runtime.endpoint( generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint] shutdown_endpoints[:] = [generate_endpoint]
# For processor, we need to connect to the encode worker # For processor, we need to connect to the encode worker
...@@ -809,9 +805,7 @@ async def init_multimodal_processor( ...@@ -809,9 +805,7 @@ async def init_multimodal_processor(
ready_event = asyncio.Event() ready_event = asyncio.Event()
handler = MultimodalProcessorHandler( handler = MultimodalProcessorHandler(config, encode_worker_client, shutdown_event)
component, config, encode_worker_client, shutdown_event
)
logging.info("Waiting for Encoder Worker Instances ...") logging.info("Waiting for Encoder Worker Instances ...")
await encode_worker_client.wait_for_instances() await encode_worker_client.wait_for_instances()
...@@ -858,7 +852,7 @@ async def init_multimodal_encode_worker( ...@@ -858,7 +852,7 @@ async def init_multimodal_encode_worker(
generate_endpoint = runtime.endpoint( generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint] shutdown_endpoints[:] = [generate_endpoint]
# For encode worker, we need to connect to the downstream LLM worker # For encode worker, we need to connect to the downstream LLM worker
...@@ -866,9 +860,7 @@ async def init_multimodal_encode_worker( ...@@ -866,9 +860,7 @@ async def init_multimodal_encode_worker(
f"{dynamo_args.namespace}.backend.generate" f"{dynamo_args.namespace}.backend.generate"
).client() ).client()
handler = MultimodalEncodeWorkerHandler( handler = MultimodalEncodeWorkerHandler(config, pd_worker_client, shutdown_event)
component, config, pd_worker_client, shutdown_event
)
await handler.async_init(runtime) await handler.async_init(runtime)
await pd_worker_client.wait_for_instances() await pd_worker_client.wait_for_instances()
...@@ -912,7 +904,7 @@ async def init_multimodal_worker( ...@@ -912,7 +904,7 @@ async def init_multimodal_worker(
generate_endpoint = runtime.endpoint( generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint] shutdown_endpoints[:] = [generate_endpoint]
engine = sgl.Engine(server_args=server_args) engine = sgl.Engine(server_args=server_args)
...@@ -923,12 +915,10 @@ async def init_multimodal_worker( ...@@ -923,12 +915,10 @@ async def init_multimodal_worker(
f"{dynamo_args.namespace}.prefill.generate" f"{dynamo_args.namespace}.prefill.generate"
).client() ).client()
handler = MultimodalWorkerHandler( handler = MultimodalWorkerHandler(
component, engine, config, prefill_client, shutdown_event engine, config, prefill_client, shutdown_event
) )
else: else:
handler = MultimodalWorkerHandler( handler = MultimodalWorkerHandler(engine, config, None, shutdown_event)
component, engine, config, None, shutdown_event
)
await handler.async_init() await handler.async_init()
...@@ -968,10 +958,11 @@ async def init_multimodal_prefill_worker( ...@@ -968,10 +958,11 @@ async def init_multimodal_prefill_worker(
generate_endpoint = runtime.endpoint( generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component()
handler = MultimodalPrefillWorkerHandler(engine, config, shutdown_event)
shutdown_endpoints[:] = [generate_endpoint] shutdown_endpoints[:] = [generate_endpoint]
handler = MultimodalPrefillWorkerHandler(component, engine, config, shutdown_event)
await handler.async_init() await handler.async_init()
health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict() health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()
......
...@@ -21,7 +21,7 @@ from dynamo.common.utils.prometheus import ( ...@@ -21,7 +21,7 @@ from dynamo.common.utils.prometheus import (
register_engine_metrics_callback, register_engine_metrics_callback,
) )
from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher
from dynamo.runtime import Component, Endpoint from dynamo.runtime import Endpoint
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
...@@ -63,7 +63,6 @@ class DynamoSglangPublisher: ...@@ -63,7 +63,6 @@ class DynamoSglangPublisher:
self, self,
engine: sgl.Engine, engine: sgl.Engine,
config: Config, config: Config,
component: Component,
generate_endpoint: Endpoint, generate_endpoint: Endpoint,
component_gauges: LLMBackendMetrics, component_gauges: LLMBackendMetrics,
metrics_labels: Optional[List[Tuple[str, str]]] = None, metrics_labels: Optional[List[Tuple[str, str]]] = None,
...@@ -73,7 +72,6 @@ class DynamoSglangPublisher: ...@@ -73,7 +72,6 @@ class DynamoSglangPublisher:
Args: Args:
engine: The SGLang engine instance. engine: The SGLang engine instance.
config: SGLang configuration including server args. config: SGLang configuration including server args.
component: The Dynamo runtime component.
generate_endpoint: The Dynamo endpoint for generation requests. generate_endpoint: The Dynamo endpoint for generation requests.
metrics_labels: Optional list of label key-value pairs for metrics. metrics_labels: Optional list of label key-value pairs for metrics.
component_gauges: LLM backend metrics instance (created via LLMBackendMetrics()). component_gauges: LLM backend metrics instance (created via LLMBackendMetrics()).
...@@ -82,7 +80,6 @@ class DynamoSglangPublisher: ...@@ -82,7 +80,6 @@ class DynamoSglangPublisher:
self.server_args = config.server_args self.server_args = config.server_args
self.dynamo_args = config.dynamo_args self.dynamo_args = config.dynamo_args
self.generate_endpoint = generate_endpoint self.generate_endpoint = generate_endpoint
self.component = component
self.metrics_publisher = WorkerMetricsPublisher() self.metrics_publisher = WorkerMetricsPublisher()
self.component_gauges = component_gauges self.component_gauges = component_gauges
# Endpoint creation is deferred to async context in setup_sgl_metrics # Endpoint creation is deferred to async context in setup_sgl_metrics
...@@ -257,7 +254,7 @@ class DynamoSglangPublisher: ...@@ -257,7 +254,7 @@ class DynamoSglangPublisher:
f"(connecting to {zmq_ep})" f"(connecting to {zmq_ep})"
) )
publisher = KvEventPublisher( publisher = KvEventPublisher(
component=self.component, endpoint=self.generate_endpoint,
kv_block_size=self.server_args.page_size, kv_block_size=self.server_args.page_size,
zmq_endpoint=zmq_ep, zmq_endpoint=zmq_ep,
zmq_topic="", zmq_topic="",
...@@ -322,7 +319,6 @@ def setup_prometheus_registry( ...@@ -322,7 +319,6 @@ def setup_prometheus_registry(
async def setup_sgl_metrics( async def setup_sgl_metrics(
engine: sgl.Engine, engine: sgl.Engine,
config: Config, config: Config,
component: Component,
generate_endpoint: Endpoint, generate_endpoint: Endpoint,
) -> tuple[DynamoSglangPublisher, asyncio.Task, list[tuple[str, str]]]: ) -> tuple[DynamoSglangPublisher, asyncio.Task, list[tuple[str, str]]]:
"""Create publisher, initialize metrics, and start the metrics publishing loop. """Create publisher, initialize metrics, and start the metrics publishing loop.
...@@ -330,7 +326,6 @@ async def setup_sgl_metrics( ...@@ -330,7 +326,6 @@ async def setup_sgl_metrics(
Args: Args:
engine: The SGLang engine instance. engine: The SGLang engine instance.
config: SGLang configuration including server args. config: SGLang configuration including server args.
component: The Dynamo runtime component.
generate_endpoint: The Dynamo endpoint for generation requests. generate_endpoint: The Dynamo endpoint for generation requests.
Returns: Returns:
...@@ -366,13 +361,12 @@ async def setup_sgl_metrics( ...@@ -366,13 +361,12 @@ async def setup_sgl_metrics(
publisher = DynamoSglangPublisher( publisher = DynamoSglangPublisher(
engine, engine,
config, config,
component,
generate_endpoint, generate_endpoint,
component_gauges=component_gauges, component_gauges=component_gauges,
metrics_labels=metrics_labels, metrics_labels=metrics_labels,
) )
# Create endpoint in async context (must await before publishing) # Create endpoint in async context (must await before publishing)
await publisher.metrics_publisher.create_endpoint(component) await publisher.metrics_publisher.create_endpoint(generate_endpoint)
logging.debug("SGLang metrics publisher endpoint created") logging.debug("SGLang metrics publisher endpoint created")
publisher.init_engine_metrics_publish() publisher.init_engine_metrics_publish()
......
...@@ -7,7 +7,7 @@ from typing import Optional ...@@ -7,7 +7,7 @@ from typing import Optional
import sglang as sgl import sglang as sgl
from dynamo._core import Component, Context from dynamo._core import Context
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.protocol import EmbeddingRequest from dynamo.sglang.protocol import EmbeddingRequest
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
...@@ -17,13 +17,12 @@ from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler ...@@ -17,13 +17,12 @@ from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
class EmbeddingWorkerHandler(BaseWorkerHandler): class EmbeddingWorkerHandler(BaseWorkerHandler):
def __init__( def __init__(
self, self,
component: Component,
engine: sgl.Engine, engine: sgl.Engine,
config: Config, config: Config,
publisher: Optional[DynamoSglangPublisher] = None, publisher: Optional[DynamoSglangPublisher] = None,
shutdown_event: Optional[asyncio.Event] = None, shutdown_event: Optional[asyncio.Event] = None,
): ):
super().__init__(component, engine, config, publisher, None, shutdown_event) super().__init__(engine, config, publisher, None, shutdown_event)
logging.info("Embedding worker handler initialized") logging.info("Embedding worker handler initialized")
def cleanup(self): def cleanup(self):
......
...@@ -13,7 +13,7 @@ from typing import Any, AsyncGenerator, Dict, Optional, Tuple ...@@ -13,7 +13,7 @@ from typing import Any, AsyncGenerator, Dict, Optional, Tuple
import sglang as sgl import sglang as sgl
from sglang.srt.utils import get_local_ip_auto from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Component, Context from dynamo._core import Context
from dynamo.common.utils.input_params import InputParamManager from dynamo.common.utils.input_params import InputParamManager
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
...@@ -30,18 +30,15 @@ class BaseGenerativeHandler(ABC): ...@@ -30,18 +30,15 @@ class BaseGenerativeHandler(ABC):
def __init__( def __init__(
self, self,
component: Component,
config: Config, config: Config,
publisher: Optional[DynamoSglangPublisher] = None, publisher: Optional[DynamoSglangPublisher] = None,
) -> None: ) -> None:
"""Initialize base generative handler. """Initialize base generative handler.
Args: Args:
component: The Dynamo runtime component.
config: SGLang and Dynamo configuration. config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher for the worker. publisher: Optional metrics publisher for the worker.
""" """
self.component = component
self.config = config self.config = config
# Set up metrics and KV publishers # Set up metrics and KV publishers
...@@ -98,7 +95,6 @@ class BaseWorkerHandler(BaseGenerativeHandler): ...@@ -98,7 +95,6 @@ class BaseWorkerHandler(BaseGenerativeHandler):
def __init__( def __init__(
self, self,
component: Component,
engine: sgl.Engine, engine: sgl.Engine,
config: Config, config: Config,
publisher: Optional[DynamoSglangPublisher] = None, publisher: Optional[DynamoSglangPublisher] = None,
...@@ -108,7 +104,6 @@ class BaseWorkerHandler(BaseGenerativeHandler): ...@@ -108,7 +104,6 @@ class BaseWorkerHandler(BaseGenerativeHandler):
"""Initialize base worker handler. """Initialize base worker handler.
Args: Args:
component: The Dynamo runtime component.
engine: The SGLang engine instance. engine: The SGLang engine instance.
config: SGLang and Dynamo configuration. config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher for the worker. publisher: Optional metrics publisher for the worker.
...@@ -116,7 +111,7 @@ class BaseWorkerHandler(BaseGenerativeHandler): ...@@ -116,7 +111,7 @@ class BaseWorkerHandler(BaseGenerativeHandler):
shutdown_event: Optional event to signal shutdown. shutdown_event: Optional event to signal shutdown.
""" """
# Call parent constructor # Call parent constructor
super().__init__(component, config, publisher) super().__init__(config, publisher)
# LLM-specific initialization # LLM-specific initialization
self.engine = engine self.engine = engine
......
...@@ -13,7 +13,7 @@ from typing import Any, AsyncGenerator, Optional ...@@ -13,7 +13,7 @@ from typing import Any, AsyncGenerator, Optional
import torch import torch
from PIL import Image from PIL import Image
from dynamo._core import Component, Context from dynamo._core import Context
from dynamo.common.storage import upload_to_fs from dynamo.common.storage import upload_to_fs
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.protocol import CreateImageRequest, ImageData, ImagesResponse, NvExt from dynamo.sglang.protocol import CreateImageRequest, ImageData, ImagesResponse, NvExt
...@@ -34,7 +34,6 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler): ...@@ -34,7 +34,6 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
def __init__( def __init__(
self, self,
component: Component,
generator: Any, # DiffGenerator, not sgl.Engine generator: Any, # DiffGenerator, not sgl.Engine
config: Config, config: Config,
publisher: Optional[DynamoSglangPublisher] = None, publisher: Optional[DynamoSglangPublisher] = None,
...@@ -43,13 +42,12 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler): ...@@ -43,13 +42,12 @@ class ImageDiffusionWorkerHandler(BaseGenerativeHandler):
"""Initialize diffusion worker handler. """Initialize diffusion worker handler.
Args: Args:
component: The Dynamo runtime component.
generator: The SGLang DiffGenerator instance. generator: The SGLang DiffGenerator instance.
config: SGLang and Dynamo configuration. config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher (not used for diffusion currently). publisher: Optional metrics publisher (not used for diffusion currently).
fs: Optional fsspec filesystem for primary image storage. fs: Optional fsspec filesystem for primary image storage.
""" """
super().__init__(component, config, publisher) super().__init__(config, publisher)
self.generator = generator # DiffGenerator, not Engine self.generator = generator # DiffGenerator, not Engine
self.fs = fs self.fs = fs
......
...@@ -8,7 +8,7 @@ from typing import Any, AsyncGenerator, Dict, Optional ...@@ -8,7 +8,7 @@ from typing import Any, AsyncGenerator, Dict, Optional
import sglang as sgl import sglang as sgl
from dynamo._core import Component, Context from dynamo._core import Context
from dynamo.common.utils.engine_response import normalize_finish_reason from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.sglang.args import Config, DisaggregationMode from dynamo.sglang.args import Config, DisaggregationMode
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
...@@ -20,7 +20,6 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -20,7 +20,6 @@ class DecodeWorkerHandler(BaseWorkerHandler):
def __init__( def __init__(
self, self,
component: Component,
engine: sgl.Engine, engine: sgl.Engine,
config: Config, config: Config,
publisher: DynamoSglangPublisher, publisher: DynamoSglangPublisher,
...@@ -30,7 +29,6 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -30,7 +29,6 @@ class DecodeWorkerHandler(BaseWorkerHandler):
"""Initialize decode worker handler. """Initialize decode worker handler.
Args: Args:
component: The Dynamo runtime component.
engine: The SGLang engine instance. engine: The SGLang engine instance.
config: SGLang and Dynamo configuration. config: SGLang and Dynamo configuration.
publisher: Metrics publisher for the worker. publisher: Metrics publisher for the worker.
...@@ -38,7 +36,6 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -38,7 +36,6 @@ class DecodeWorkerHandler(BaseWorkerHandler):
generate_endpoint: The endpoint handle for discovery registration. generate_endpoint: The endpoint handle for discovery registration.
""" """
super().__init__( super().__init__(
component,
engine, engine,
config, config,
publisher, publisher,
......
...@@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, Dict, Optional ...@@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, Dict, Optional
import sglang as sgl import sglang as sgl
from dynamo._core import Component, Context from dynamo._core import Context
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.llm.decode_handler import DecodeWorkerHandler from dynamo.sglang.request_handlers.llm.decode_handler import DecodeWorkerHandler
...@@ -19,7 +19,6 @@ class DiffusionWorkerHandler(DecodeWorkerHandler): ...@@ -19,7 +19,6 @@ class DiffusionWorkerHandler(DecodeWorkerHandler):
def __init__( def __init__(
self, self,
component: Component,
engine: sgl.Engine, engine: sgl.Engine,
config: Config, config: Config,
publisher: DynamoSglangPublisher = None, publisher: DynamoSglangPublisher = None,
...@@ -29,16 +28,13 @@ class DiffusionWorkerHandler(DecodeWorkerHandler): ...@@ -29,16 +28,13 @@ class DiffusionWorkerHandler(DecodeWorkerHandler):
"""Initialize diffusion worker handler. """Initialize diffusion worker handler.
Args: Args:
component: The Dynamo runtime component.
engine: SGLang engine with diffusion algorithm configured. engine: SGLang engine with diffusion algorithm configured.
config: SGLang and Dynamo configuration. config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher. publisher: Optional metrics publisher.
generate_endpoint: The endpoint handle for discovery. generate_endpoint: The endpoint handle for discovery.
shutdown_event: Optional event to signal shutdown. shutdown_event: Optional event to signal shutdown.
""" """
super().__init__( super().__init__(engine, config, publisher, generate_endpoint, shutdown_event)
component, engine, config, publisher, generate_endpoint, shutdown_event
)
# Validate that diffusion algorithm is configured # Validate that diffusion algorithm is configured
if ( if (
......
...@@ -7,7 +7,7 @@ from typing import Any, AsyncGenerator, Dict, Optional ...@@ -7,7 +7,7 @@ from typing import Any, AsyncGenerator, Dict, Optional
import sglang as sgl import sglang as sgl
from dynamo._core import Component, Context from dynamo._core import Context
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
...@@ -18,7 +18,6 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -18,7 +18,6 @@ class PrefillWorkerHandler(BaseWorkerHandler):
def __init__( def __init__(
self, self,
component: Component,
engine: sgl.Engine, engine: sgl.Engine,
config: Config, config: Config,
publisher: DynamoSglangPublisher, publisher: DynamoSglangPublisher,
...@@ -28,7 +27,6 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -28,7 +27,6 @@ class PrefillWorkerHandler(BaseWorkerHandler):
"""Initialize prefill worker handler. """Initialize prefill worker handler.
Args: Args:
component: The Dynamo runtime component.
engine: The SGLang engine instance. engine: The SGLang engine instance.
config: SGLang and Dynamo configuration. config: SGLang and Dynamo configuration.
publisher: The SGLang publisher instance. publisher: The SGLang publisher instance.
...@@ -37,9 +35,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -37,9 +35,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
""" """
self.engine = engine self.engine = engine
self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info(self.engine) self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info(self.engine)
super().__init__( super().__init__(engine, config, publisher, generate_endpoint, shutdown_event)
component, engine, config, publisher, generate_endpoint, shutdown_event
)
self._consume_tasks = set() self._consume_tasks = set()
logging.info( logging.info(
f"Prefill worker handler initialized - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}" f"Prefill worker handler initialized - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}"
......
...@@ -16,7 +16,7 @@ from sglang.srt.parser.conversation import chat_templates ...@@ -16,7 +16,7 @@ from sglang.srt.parser.conversation import chat_templates
from transformers import AutoTokenizer from transformers import AutoTokenizer
import dynamo.nixl_connect as connect import dynamo.nixl_connect as connect
from dynamo._core import Client, Component, Context from dynamo._core import Client, Context
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.protocol import SglangMultimodalRequest from dynamo.sglang.protocol import SglangMultimodalRequest
...@@ -46,14 +46,11 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -46,14 +46,11 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
def __init__( def __init__(
self, self,
component: Component,
config: Config, config: Config,
pd_worker_client: Client, pd_worker_client: Client,
shutdown_event: Optional[asyncio.Event] = None, shutdown_event: Optional[asyncio.Event] = None,
) -> None: ) -> None:
super().__init__( super().__init__(engine=None, config=config, shutdown_event=shutdown_event)
component, engine=None, config=config, shutdown_event=shutdown_event
)
self.pd_worker_client = pd_worker_client self.pd_worker_client = pd_worker_client
self.model = config.server_args.model_path self.model = config.server_args.model_path
......
...@@ -10,7 +10,7 @@ from typing import Any, Dict, Optional ...@@ -10,7 +10,7 @@ from typing import Any, Dict, Optional
from transformers import AutoTokenizer from transformers import AutoTokenizer
from dynamo._core import Client, Component, Context from dynamo._core import Client, Context
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.multimodal_utils import ( from dynamo.sglang.multimodal_utils import (
multimodal_request_to_sglang, multimodal_request_to_sglang,
...@@ -34,14 +34,11 @@ class MultimodalProcessorHandler(BaseWorkerHandler): ...@@ -34,14 +34,11 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
def __init__( def __init__(
self, self,
component: Component,
config: Config, config: Config,
encode_worker_client: Client, encode_worker_client: Client,
shutdown_event: Optional[asyncio.Event] = None, shutdown_event: Optional[asyncio.Event] = None,
): ):
super().__init__( super().__init__(engine=None, config=config, shutdown_event=shutdown_event)
component, engine=None, config=config, shutdown_event=shutdown_event
)
self.encode_worker_client = encode_worker_client self.encode_worker_client = encode_worker_client
self.chat_template = getattr(config.server_args, "chat_template", "qwen2-vl") self.chat_template = getattr(config.server_args, "chat_template", "qwen2-vl")
self.model = config.server_args.model_path self.model = config.server_args.model_path
......
...@@ -10,7 +10,7 @@ import sglang as sgl ...@@ -10,7 +10,7 @@ import sglang as sgl
import torch import torch
import dynamo.nixl_connect as connect import dynamo.nixl_connect as connect
from dynamo._core import Client, Component, Context from dynamo._core import Client, Context
from dynamo.common.utils.engine_response import normalize_finish_reason from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.sglang.args import Config, DisaggregationMode from dynamo.sglang.args import Config, DisaggregationMode
from dynamo.sglang.protocol import ( from dynamo.sglang.protocol import (
...@@ -253,13 +253,12 @@ class MultimodalWorkerHandler(BaseWorkerHandler): ...@@ -253,13 +253,12 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
def __init__( def __init__(
self, self,
component: Component,
engine: sgl.Engine, engine: sgl.Engine,
config: Config, config: Config,
prefill_client: Client = None, prefill_client: Client = None,
shutdown_event: Optional[asyncio.Event] = None, shutdown_event: Optional[asyncio.Event] = None,
): ):
super().__init__(component, engine, config, None, None, shutdown_event) super().__init__(engine, config, None, None, shutdown_event)
# Initialize processors # Initialize processors
self.embeddings_processor = EmbeddingsProcessor() self.embeddings_processor = EmbeddingsProcessor()
...@@ -435,12 +434,11 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler): ...@@ -435,12 +434,11 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
def __init__( def __init__(
self, self,
component: Component,
engine: sgl.Engine, engine: sgl.Engine,
config: Config, config: Config,
shutdown_event: Optional[asyncio.Event] = None, shutdown_event: Optional[asyncio.Event] = None,
): ):
super().__init__(component, engine, config, None, None, shutdown_event) super().__init__(engine, config, None, None, shutdown_event)
# Initialize processors # Initialize processors
self.embeddings_processor = EmbeddingsProcessor() self.embeddings_processor = EmbeddingsProcessor()
......
...@@ -11,7 +11,7 @@ from typing import Any, AsyncGenerator, Optional ...@@ -11,7 +11,7 @@ from typing import Any, AsyncGenerator, Optional
import torch import torch
from dynamo._core import Component, Context from dynamo._core import Context
from dynamo.common.storage import upload_to_fs from dynamo.common.storage import upload_to_fs
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.protocol import ( from dynamo.sglang.protocol import (
...@@ -35,7 +35,6 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler): ...@@ -35,7 +35,6 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
def __init__( def __init__(
self, self,
component: Component,
generator: Any, # DiffGenerator, not sgl.Engine generator: Any, # DiffGenerator, not sgl.Engine
config: Config, config: Config,
publisher: Optional[DynamoSglangPublisher] = None, publisher: Optional[DynamoSglangPublisher] = None,
...@@ -44,14 +43,13 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler): ...@@ -44,14 +43,13 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
"""Initialize video generation worker handler. """Initialize video generation worker handler.
Args: Args:
component: The Dynamo runtime component.
generator: The SGLang DiffGenerator instance. generator: The SGLang DiffGenerator instance.
config: SGLang and Dynamo configuration. config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher (not used for video currently). publisher: Optional metrics publisher (not used for video currently).
fs: Optional fsspec filesystem for primary video storage. fs: Optional fsspec filesystem for primary video storage.
""" """
# Call parent constructor for common setup # Call parent constructor for common setup
super().__init__(component, config, publisher) super().__init__(config, publisher)
# Video generation-specific initialization # Video generation-specific initialization
self.generator = generator # DiffGenerator, not Engine self.generator = generator # DiffGenerator, not Engine
......
...@@ -23,12 +23,6 @@ pytestmark = [ ...@@ -23,12 +23,6 @@ pytestmark = [
] ]
@pytest.fixture
def mock_component():
"""Mock Dynamo Component."""
return MagicMock()
@pytest.fixture @pytest.fixture
def mock_generator(): def mock_generator():
"""Mock SGLang DiffGenerator.""" """Mock SGLang DiffGenerator."""
...@@ -67,12 +61,9 @@ def mock_context(): ...@@ -67,12 +61,9 @@ def mock_context():
@pytest.fixture @pytest.fixture
def handler( def handler(mock_generator, mock_config, mock_fs) -> ImageDiffusionWorkerHandler:
mock_component, mock_generator, mock_config, mock_fs
) -> ImageDiffusionWorkerHandler:
"""Create ImageDiffusionWorkerHandler instance.""" """Create ImageDiffusionWorkerHandler instance."""
return ImageDiffusionWorkerHandler( return ImageDiffusionWorkerHandler(
component=mock_component,
generator=mock_generator, generator=mock_generator,
config=mock_config, config=mock_config,
publisher=None, publisher=None,
...@@ -90,9 +81,7 @@ class TestImageDiffusionWorkerHandler: ...@@ -90,9 +81,7 @@ class TestImageDiffusionWorkerHandler:
assert handler.fs_url == "file:///tmp/images" assert handler.fs_url == "file:///tmp/images"
assert handler.base_url == "file:///tmp/images" assert handler.base_url == "file:///tmp/images"
def test_initialization_with_url_base( def test_initialization_with_url_base(self, mock_generator, mock_fs):
self, mock_component, mock_generator, mock_fs
):
"""Test handler initialization with URL base.""" """Test handler initialization with URL base."""
config = MagicMock() config = MagicMock()
config.dynamo_args = MagicMock() config.dynamo_args = MagicMock()
...@@ -100,7 +89,6 @@ class TestImageDiffusionWorkerHandler: ...@@ -100,7 +89,6 @@ class TestImageDiffusionWorkerHandler:
config.dynamo_args.media_output_http_url = "http://localhost:8008/images" config.dynamo_args.media_output_http_url = "http://localhost:8008/images"
handler = ImageDiffusionWorkerHandler( handler = ImageDiffusionWorkerHandler(
component=mock_component,
generator=mock_generator, generator=mock_generator,
config=config, config=config,
publisher=None, publisher=None,
......
...@@ -296,9 +296,8 @@ class Publisher: ...@@ -296,9 +296,8 @@ class Publisher:
def __init__( def __init__(
self, self,
component, endpoint,
engine, engine,
kv_listener,
worker_id, worker_id,
kv_block_size, kv_block_size,
metrics_labels, metrics_labels,
...@@ -306,9 +305,8 @@ class Publisher: ...@@ -306,9 +305,8 @@ class Publisher:
zmq_endpoint: Optional[str] = None, zmq_endpoint: Optional[str] = None,
enable_local_indexer: bool = False, enable_local_indexer: bool = False,
): ):
self.component = component self.endpoint = endpoint
self.engine = engine self.engine = engine
self.kv_listener = kv_listener
self.worker_id = worker_id self.worker_id = worker_id
self.kv_block_size = kv_block_size self.kv_block_size = kv_block_size
self.max_window_size = None self.max_window_size = None
...@@ -356,7 +354,7 @@ class Publisher: ...@@ -356,7 +354,7 @@ class Publisher:
if self.metrics_publisher is None: if self.metrics_publisher is None:
logging.error("KV metrics publisher not initialized!") logging.error("KV metrics publisher not initialized!")
return return
await self.metrics_publisher.create_endpoint(self.component) await self.metrics_publisher.create_endpoint(self.endpoint)
def initialize(self): def initialize(self):
# Setup the metrics publisher # Setup the metrics publisher
...@@ -385,9 +383,9 @@ class Publisher: ...@@ -385,9 +383,9 @@ class Publisher:
self.kv_event_publishers = {} self.kv_event_publishers = {}
for rank in range(self.attention_dp_size): for rank in range(self.attention_dp_size):
self.kv_event_publishers[rank] = KvEventPublisher( self.kv_event_publishers[rank] = KvEventPublisher(
self.kv_listener, endpoint=self.endpoint,
self.worker_id, worker_id=self.worker_id,
self.kv_block_size, kv_block_size=self.kv_block_size,
dp_rank=rank, dp_rank=rank,
enable_local_indexer=self.enable_local_indexer, enable_local_indexer=self.enable_local_indexer,
) )
...@@ -760,9 +758,8 @@ class Publisher: ...@@ -760,9 +758,8 @@ class Publisher:
@asynccontextmanager @asynccontextmanager
async def get_publisher( async def get_publisher(
component, endpoint,
engine, engine,
kv_listener,
worker_id, worker_id,
kv_block_size, kv_block_size,
metrics_labels, metrics_labels,
...@@ -771,9 +768,8 @@ async def get_publisher( ...@@ -771,9 +768,8 @@ async def get_publisher(
enable_local_indexer: bool = False, enable_local_indexer: bool = False,
): ):
publisher = Publisher( publisher = Publisher(
component, endpoint,
engine, engine,
kv_listener,
worker_id, worker_id,
kv_block_size, kv_block_size,
metrics_labels, metrics_labels,
......
...@@ -55,7 +55,6 @@ class RequestHandlerConfig: ...@@ -55,7 +55,6 @@ class RequestHandlerConfig:
Configuration for the request handler Configuration for the request handler
""" """
component: object
engine: TensorRTLLMEngine engine: TensorRTLLMEngine
default_sampling_params: SamplingParams default_sampling_params: SamplingParams
publisher: Publisher publisher: Publisher
...@@ -89,7 +88,6 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -89,7 +88,6 @@ class HandlerBase(BaseGenerativeHandler):
def __init__(self, config: RequestHandlerConfig): def __init__(self, config: RequestHandlerConfig):
self.engine = config.engine self.engine = config.engine
self.component = config.component
self.default_sampling_params = config.default_sampling_params self.default_sampling_params = config.default_sampling_params
self.publisher = config.publisher self.publisher = config.publisher
self.metrics_collector = config.metrics_collector self.metrics_collector = config.metrics_collector
......
...@@ -13,7 +13,7 @@ import time ...@@ -13,7 +13,7 @@ import time
import uuid import uuid
from typing import Any, AsyncGenerator, Optional from typing import Any, AsyncGenerator, Optional
from dynamo._core import Component, Context from dynamo._core import Context
from dynamo.common.protocols.video_protocol import ( from dynamo.common.protocols.video_protocol import (
NvCreateVideoRequest, NvCreateVideoRequest,
NvVideosResponse, NvVideosResponse,
...@@ -42,18 +42,15 @@ class VideoGenerationHandler(BaseGenerativeHandler): ...@@ -42,18 +42,15 @@ class VideoGenerationHandler(BaseGenerativeHandler):
def __init__( def __init__(
self, self,
component: Component,
engine: DiffusionEngine, engine: DiffusionEngine,
config: DiffusionConfig, config: DiffusionConfig,
): ):
"""Initialize the handler. """Initialize the handler.
Args: Args:
component: The Dynamo runtime component.
engine: The DiffusionEngine instance. engine: The DiffusionEngine instance.
config: Diffusion generation configuration. config: Diffusion generation configuration.
""" """
self.component = component
self.engine = engine self.engine = engine
self.config = config self.config = config
if not config.media_output_fs_url: if not config.media_output_fs_url:
......
...@@ -597,7 +597,6 @@ class TestVideoHandlerConcurrency: ...@@ -597,7 +597,6 @@ class TestVideoHandlerConcurrency:
return_value=MagicMock(), return_value=MagicMock(),
): ):
handler = VideoGenerationHandler( handler = VideoGenerationHandler(
component=MagicMock(),
engine=mock_engine, engine=mock_engine,
config=config, config=config,
) )
...@@ -684,7 +683,6 @@ class TestVideoHandlerResponseFormats: ...@@ -684,7 +683,6 @@ class TestVideoHandlerResponseFormats:
return_value=MagicMock(), return_value=MagicMock(),
): ):
handler = VideoGenerationHandler( handler = VideoGenerationHandler(
component=MagicMock(),
engine=mock_engine, engine=mock_engine,
config=config, config=config,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment