Unverified commit 2712426f, authored by jh-nv and committed by GitHub
Browse files

feat: enable mypy in pre-merge (#6732)

parent e5e118a1
......@@ -4,7 +4,6 @@
import asyncio
import logging
import time
from typing import Optional
from dynamo.planner import SubComponentType, TargetReplica
from dynamo.planner.utils.decode_planner import DecodePlanner
......@@ -24,9 +23,7 @@ logger = logging.getLogger(__name__)
class DisaggPlanner:
def __init__(
self, runtime: Optional[DistributedRuntime], config: PlannerConfig
) -> None:
def __init__(self, runtime: DistributedRuntime, config: PlannerConfig) -> None:
self.config = config
self.shared_state = PlannerSharedState()
prometheus_metrics = PlannerPrometheusMetrics()
......@@ -89,13 +86,12 @@ class DisaggPlanner:
logger.info(f"Detected model name from deployment: {model_name}")
model_name = model_name.lower()
else:
model_name = getattr(self.config, "model_name", None)
if not model_name:
if not self.config.model_name:
raise ValueError(
"Model name is required in no-operation mode. "
"Please set model_name in the config."
)
model_name = model_name.lower()
model_name = self.config.model_name.lower()
self.prefill_planner.model_name = model_name
self.decode_planner.model_name = model_name
......
......@@ -127,6 +127,13 @@ def run_sla_planner_dryrun(
time_series.append(time_series[-1] + interval)
_est_rr, _est_isl, _est_osl = predictor_planner.predict_load()
# predict_load() returns Optional[float] values; in dryrun mode with
# pre-loaded data the predictors always return valid floats.
assert (
_est_rr is not None and _est_isl is not None and _est_osl is not None
), "predict_load() returned None in dryrun mode"
est_rr.append(_est_rr)
est_isl.append(_est_isl)
est_osl.append(_est_osl)
......@@ -145,10 +152,12 @@ def run_sla_planner_dryrun(
if prefill_planner is not None and decode_planner is not None:
_num_p, _num_d = _apply_global_gpu_budget(_num_p, _num_d, config)
elif prefill_planner is not None:
assert config.prefill_engine_num_gpu is not None
_num_p = _apply_component_gpu_budget(
_num_p, config.prefill_engine_num_gpu, config
)
elif decode_planner is not None:
assert config.decode_engine_num_gpu is not None
_num_d = _apply_component_gpu_budget(
_num_d, config.decode_engine_num_gpu, config
)
......
......@@ -19,7 +19,7 @@ import warnings
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from enum import Enum
from typing import Any
from typing import Any, Callable
import numpy as np
import pandas as pd
......@@ -389,7 +389,7 @@ class KalmanPredictor(BasePredictor):
)
LOAD_PREDICTORS = {
LOAD_PREDICTORS: dict[str, Callable[[PlannerConfig], BasePredictor]] = {
"constant": ConstantPredictor,
"arima": ARIMAPredictor,
"kalman": KalmanPredictor,
......
......@@ -151,6 +151,8 @@ class DecodeInterpolator:
self.resolution = resolution
self.xi = np.linspace(0, 1, resolution)
self.yi = np.linspace(0, max(self.y_context_length), resolution)
self.X: np.ndarray
self.Y: np.ndarray
self.X, self.Y = np.meshgrid(self.xi, self.yi)
# Lazy import scipy only when interpolation is actually needed
......
......@@ -6,7 +6,7 @@ import logging
import math
import time
from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, Union
from prometheus_client import Gauge, start_http_server
......@@ -36,6 +36,9 @@ from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_moonc
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
# Union of all connector types used by the planner
ConnectorType = Union[GlobalPlannerConnector, KubernetesConnector, VirtualConnector]
configure_dynamo_logging()
logger = logging.getLogger(__name__)
......@@ -248,14 +251,14 @@ class BasePlanner:
def __init__(
self,
runtime: DistributedRuntime,
runtime: Optional[DistributedRuntime],
config: PlannerConfig,
dryrun: bool = False,
shared_state: Optional[PlannerSharedState] = None,
prometheus_metrics: Optional[PlannerPrometheusMetrics] = None,
prometheus_traffic_client: Optional[PrometheusAPIClient] = None,
prometheus_engine_client: Optional[DirectRouterMetricsClient] = None,
connector=None,
connector: Optional[ConnectorType] = None,
start_prometheus_server: bool = True,
component_type: Optional[SubComponentType] = None,
):
......@@ -272,11 +275,13 @@ class BasePlanner:
if not self.dryrun:
self.runtime = runtime
self.namespace = config.namespace
self.connector: ConnectorType
if not config.no_operation:
# Initialize connector based on environment
if config.environment == "global-planner":
assert config.global_planner_namespace is not None
assert runtime is not None
self.connector = GlobalPlannerConnector(
runtime,
self.namespace,
......@@ -289,6 +294,7 @@ class BasePlanner:
self.namespace, self.model_name
)
elif config.environment == "virtual":
assert runtime is not None
self.connector = VirtualConnector(
runtime,
self.namespace,
......@@ -430,11 +436,12 @@ class BasePlanner:
self.prometheus_engine_client = prometheus_engine_client
else:
# Auto-discover frontend metrics URL in Kubernetes mode
connector = getattr(self, "connector", None)
if not config.load_router_metrics_url and isinstance(
getattr(self, "connector", None), KubernetesConnector
connector, KubernetesConnector
):
config.load_router_metrics_url = (
self.connector.get_frontend_metrics_url()
connector.get_frontend_metrics_url()
)
if not config.load_router_metrics_url:
raise ValueError(
......@@ -447,6 +454,9 @@ class BasePlanner:
f"Auto-discovered frontend metrics URL: {config.load_router_metrics_url}"
)
assert (
config.load_router_metrics_url is not None
), "load_router_metrics_url must be set when load-based scaling is enabled"
self.prometheus_engine_client = DirectRouterMetricsClient(
config.load_router_metrics_url, config.namespace
)
......@@ -494,6 +504,7 @@ class BasePlanner:
async def _get_or_create_client(self, component_name: str, endpoint_name: str):
"""Create a client for the given component and endpoint, with a brief sleep for state sync."""
assert self.runtime is not None, "Runtime is not initialized"
client = await self.runtime.endpoint(
f"{self.namespace}.{component_name}.{endpoint_name}"
).client()
......@@ -604,41 +615,46 @@ class BasePlanner:
)
# Prometheus returns seconds, convert to milliseconds
assert (
self.model_name is not None
), "model_name must be set before observing traffic stats"
interval_str = f"{self.config.throughput_adjustment_interval}s"
self.last_metrics.ttft = (
self.prometheus_traffic_client.get_avg_time_to_first_token(
f"{self.config.throughput_adjustment_interval}s",
interval_str,
self.model_name,
)
* 1000
)
self.last_metrics.itl = (
self.prometheus_traffic_client.get_avg_inter_token_latency(
f"{self.config.throughput_adjustment_interval}s",
interval_str,
self.model_name,
)
* 1000
)
self.last_metrics.num_req = (
self.prometheus_traffic_client.get_avg_request_count(
f"{self.config.throughput_adjustment_interval}s",
interval_str,
self.model_name,
)
)
self.last_metrics.request_duration = (
self.prometheus_traffic_client.get_avg_request_duration(
f"{self.config.throughput_adjustment_interval}s",
interval_str,
self.model_name,
)
)
self.last_metrics.isl = (
self.prometheus_traffic_client.get_avg_input_sequence_tokens(
f"{self.config.throughput_adjustment_interval}s",
interval_str,
self.model_name,
)
)
self.last_metrics.osl = (
self.prometheus_traffic_client.get_avg_output_sequence_tokens(
f"{self.config.throughput_adjustment_interval}s",
interval_str,
self.model_name,
)
)
......@@ -666,8 +682,11 @@ class BasePlanner:
self.update_predictors_from_metrics(self.last_metrics)
# Forward the observed `num_req`, `isl`, and `osl` values from `metrics` to
# their respective predictors via `add_data_point`. Each field is checked
# independently; a field that is None is skipped, so its predictor receives
# no data point for this observation.
def update_predictors_from_metrics(self, metrics: Metrics) -> None:
if metrics.num_req is not None:
self.num_req_predictor.add_data_point(metrics.num_req)
if metrics.isl is not None:
self.isl_predictor.add_data_point(metrics.isl)
if metrics.osl is not None:
self.osl_predictor.add_data_point(metrics.osl)
def predict_load(self) -> tuple[Optional[float], Optional[float], Optional[float]]:
......
......@@ -94,6 +94,7 @@ class PrefillPlanner(BasePlanner):
return None
def _update_correction_factor(self) -> bool:
assert self.last_metrics.isl is not None and self.last_metrics.ttft is not None
expect_ttft = self.prefill_interpolator.interpolate_ttft(self.last_metrics.isl)
self.p_correction_factor = self.last_metrics.ttft / expect_ttft
logger.info(f"Correction factor (prefill TTFT): {self.p_correction_factor:.3f}")
......@@ -117,6 +118,7 @@ class PrefillPlanner(BasePlanner):
"(no throughput satisfies TTFT target), falling back to min_endpoint"
)
return self.config.min_endpoint
assert self.config.prefill_engine_num_gpu is not None
next_num_p = math.ceil(
pred_prefill_throughput
/ p_thpt_per_gpu
......
......@@ -4,7 +4,7 @@
"""Dynamo SGLang wrapper configuration ArgGroup."""
import argparse
from typing import Optional, Union
from typing import Optional
from dynamo.common.configuration.arg_group import ArgGroup
from dynamo.common.configuration.config_base import ConfigBase
......@@ -117,7 +117,7 @@ class DynamoSGLangConfig(ConfigBase):
multimodal_processor: bool
multimodal_encode_worker: bool
multimodal_worker: bool
embedding_transfer_mode: Union[str, EmbeddingTransferMode]
embedding_transfer_mode: EmbeddingTransferMode
embedding_worker: bool
image_diffusion_worker: bool
......@@ -127,10 +127,11 @@ class DynamoSGLangConfig(ConfigBase):
video_generation_worker: bool
def validate(self) -> None:
if isinstance(self.embedding_transfer_mode, str):
if not isinstance(self.embedding_transfer_mode, EmbeddingTransferMode):
self.embedding_transfer_mode = EmbeddingTransferMode(
self.embedding_transfer_mode
str(self.embedding_transfer_mode)
)
if (self.disagg_config is not None) ^ (self.disagg_config_key is not None):
raise ValueError(
"Both 'disagg_config' and 'disagg_config_key' must be provided together."
......
......@@ -8,13 +8,23 @@ import random
import socket
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Optional, Tuple
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Dict,
Generic,
Optional,
Tuple,
TypeVar,
)
import sglang as sgl
from sglang.srt.utils import get_local_ip_auto
from dynamo._core import Context
from dynamo.common.utils.input_params import InputParamManager
from dynamo.llm import KvEventPublisher, WorkerMetricsPublisher
from dynamo.runtime import DistributedRuntime
from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
......@@ -72,7 +82,11 @@ class SGLangEngineQuiesceController:
self._is_quiesced = False
class BaseGenerativeHandler(ABC):
RequestT = TypeVar("RequestT")
ResponseT = TypeVar("ResponseT")
class BaseGenerativeHandler(ABC, Generic[RequestT, ResponseT]):
"""Minimal base class for all generative handlers (LLM, diffusion, etc.).
Provides common infrastructure for:
......@@ -95,27 +109,24 @@ class BaseGenerativeHandler(ABC):
self.config = config
# Set up metrics and KV publishers
self.metrics_publisher: Optional[WorkerMetricsPublisher] = None
self.kv_publisher: Optional[KvEventPublisher] = None
if publisher is not None:
self.metrics_publisher = publisher.metrics_publisher
self.kv_publisher = publisher.kv_publisher
else:
self.metrics_publisher = None
self.kv_publisher = None
@abstractmethod
async def generate(
self, request: Dict[str, Any], context: Context
) -> AsyncGenerator[Dict[str, Any], None]:
def generate(self, request: RequestT, context: Context) -> AsyncIterator[ResponseT]:
"""Generate response from request.
Args:
request: Request dict with input and parameters.
request: Request with input and parameters.
context: Context object for cancellation handling.
Yields:
Response data (format varies by handler implementation).
"""
pass
...
def cleanup(self) -> None:
"""Cleanup resources. Override in subclasses as needed."""
......@@ -137,7 +148,7 @@ class BaseGenerativeHandler(ABC):
return {"traceparent": f"00-{trace_id}-{span_id}-01"}
class BaseWorkerHandler(BaseGenerativeHandler):
class BaseWorkerHandler(BaseGenerativeHandler[RequestT, ResponseT]):
"""Abstract base class for SGLang LLM worker handlers.
Extends BaseGenerativeHandler with LLM-specific functionality:
......@@ -175,9 +186,6 @@ class BaseWorkerHandler(BaseGenerativeHandler):
if publisher is not None:
self.metrics_publisher = publisher.metrics_publisher
self.kv_publisher = publisher.kv_publisher
else:
self.metrics_publisher = None
self.kv_publisher = None
self.serving_mode = config.serving_mode
self.skip_tokenizer_init = config.server_args.skip_tokenizer_init
self.enable_trace = config.server_args.enable_trace
......@@ -454,17 +462,17 @@ class BaseWorkerHandler(BaseGenerativeHandler):
)
@abstractmethod
async def generate(self, request: Dict[str, Any], context: Context):
def generate(self, request: RequestT, context: Context) -> AsyncIterator[ResponseT]:
"""Generate response from request.
Args:
request: Request dict with input and parameters.
request: Request with input and parameters.
context: Context object for cancellation handling.
Yields:
Response data (format varies by handler implementation).
"""
pass
...
def cleanup(self) -> None:
"""Cleanup resources. Override in subclasses as needed."""
......
......@@ -24,7 +24,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
self,
engine: sgl.Engine,
config: Config,
publisher: DynamoSglangPublisher,
publisher: Optional[DynamoSglangPublisher] = None,
generate_endpoint=None,
shutdown_event: Optional[asyncio.Event] = None,
) -> None:
......@@ -230,7 +230,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
# This lets SGLang proceed to the second token generation, which will
# async context switch and allow the abort monitor to signal cancellation.
# The loop should exit by itself when context.is_stopped() returns True.
out = {}
out: dict[str, Any] = {}
finish_reason = res["meta_info"]["finish_reason"]
if finish_reason:
out["finish_reason"] = normalize_finish_reason(
......
......@@ -21,7 +21,7 @@ class DiffusionWorkerHandler(DecodeWorkerHandler):
self,
engine: sgl.Engine,
config: Config,
publisher: DynamoSglangPublisher = None,
publisher: Optional[DynamoSglangPublisher] = None,
generate_endpoint=None,
shutdown_event: Optional[asyncio.Event] = None,
) -> None:
......
......@@ -38,7 +38,7 @@ except ImportError as e:
DEVICE = "cpu"
class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
class MultimodalEncodeWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
"""
Handler for multimodal encode worker component that processes images/videos
and forwards them to the downstream worker.
......@@ -84,12 +84,19 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
if image_token_str == "<|vision_start|><|image_pad|><|vision_end|>":
# These are likely the individual special tokens for Qwen2.5-VL
image_pad_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>")
assert isinstance(
image_pad_id, int
), f"Expected int token id, got {type(image_pad_id)}"
# Use the image_pad token as the main image token
self.image_token_id = image_pad_id
self.image_token_id: int = image_pad_id
else:
# Fallback for other models
self.image_token_id = self.tokenizer.convert_tokens_to_ids(image_token_str)
token_id = self.tokenizer.convert_tokens_to_ids(image_token_str)
assert isinstance(
token_id, int
), f"Expected int token id, got {type(token_id)}"
self.image_token_id = token_id
self.min_workers = 1
......@@ -230,10 +237,11 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
zip(multimodal_groups, image_grid_thw_list)
):
mm_group.image_grid_thw = image_grid_thw
if mm_group.multimodal_input is not None:
mm_group.multimodal_input.image_url = None
# Store shared tensor transfer metadata at request level.
request.embeddings_shape = tuple(precomputed_embeddings.shape)
request.embeddings_shape = tuple(precomputed_embeddings.shape) # type: ignore[assignment]
request.transfer_payload = None
search_start = 0
......
......@@ -6,7 +6,7 @@ import json
import logging
import time
import uuid
from typing import Any, Dict, Optional
from typing import Any, AsyncGenerator, Dict, Optional
from transformers import AutoTokenizer
......@@ -20,6 +20,7 @@ from dynamo.sglang.protocol import (
MultiModalGroup,
MultiModalInput,
MultiModalRequest,
PreprocessedRequest,
SglangMultimodalRequest,
)
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
......@@ -27,7 +28,7 @@ from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
logger = logging.getLogger(__name__)
class MultimodalProcessorHandler(BaseWorkerHandler):
class MultimodalProcessorHandler(BaseWorkerHandler[MultiModalRequest, Dict[str, Any]]):
"""
Handler for multimodal processor component that processes multimodal requests
and forwards them to the encode worker.
......@@ -56,7 +57,9 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
# No-op override of the base handler's cleanup hook: the body performs no work.
def cleanup(self):
pass
async def generate(self, raw_request: MultiModalRequest, context: Context):
async def generate(
self, raw_request: MultiModalRequest, context: Context
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Process multimodal request and forward to encode worker.
......@@ -119,7 +122,7 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
)
worker_request = SglangMultimodalRequest(
request=sglang_request,
request=PreprocessedRequest(**sglang_request),
multimodal_inputs=multimodal_groups,
)
......
......@@ -256,7 +256,7 @@ async def _build_mm_items(
return mm_items, embeddings, tensor_id
class MultimodalWorkerHandler(BaseWorkerHandler):
class MultimodalWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
"""
Multimodal worker handler for LLM inference with multimodal data.
Handles both aggregated and disaggregated modes.
......@@ -490,7 +490,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
logger.info("Multimodal worker engine shutdown")
class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
class MultimodalPrefillWorkerHandler(
BaseWorkerHandler[DisaggSglangMultimodalRequest, str]
):
"""
Multimodal prefill worker handler for disaggregated inference
Processes multimodal inputs and coordinates with decode worker.
......
......@@ -103,16 +103,22 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
)
# Parse size
assert req.size is not None, "Size is required"
width, height = self._parse_size(req.size)
# Calculate num_frames if not explicitly provided
num_frames = nvext.num_frames
assert nvext.fps is not None, "FPS is required"
if num_frames is None:
assert req.seconds is not None, "Seconds is required"
num_frames = nvext.fps * req.seconds
# Generate video
context_id = context.id()
assert context_id is not None
assert (
nvext.num_inference_steps is not None
), "Num inference steps is required"
video_bytes = await self._generate_video(
prompt=req.prompt,
width=width,
......
......@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import List, Optional
from typing import List, Optional, Sequence
import torch
from tensorrt_llm.sampling_params import LogitsProcessor
......@@ -70,7 +70,7 @@ class TrtllmDynamoLogitsAdapter(LogitsProcessor):
def create_trtllm_adapters(
processors: List[BaseLogitsProcessor],
processors: Sequence[BaseLogitsProcessor],
) -> List[TrtllmDynamoLogitsAdapter]:
"""
Create TensorRT-LLM compatible adapters from Dynamo logits processors.
......
......@@ -5,7 +5,9 @@
import logging
from collections.abc import AsyncGenerator
from typing import Optional
from typing import Optional, Union
import torch
from dynamo._core import Context
from dynamo.common.memory.multimodal_embedding_cache_manager import (
......@@ -40,7 +42,7 @@ class AggregatedHandler(HandlerBase):
"""Generate response, optionally using remote encoder for multimodal."""
logging.debug(f"AggregatedHandler Request ID: {context.id()}")
embeddings = None
embeddings: Optional[Union[torch.Tensor, dict]] = None
ep_disaggregated_params = None
if self.multimodal_processor and self.encode_client:
messages = request.get("extra_args", {}).get(
......@@ -57,7 +59,7 @@ class AggregatedHandler(HandlerBase):
self._encoder_cache,
)
if isinstance(result, list):
embeddings = result
embeddings = result # type: ignore[assignment]
else:
ep_disaggregated_params = result
......
......@@ -31,7 +31,7 @@ from tensorrt_llm.llmapi.llm import SamplingParams
from tensorrt_llm.sampling_params import GuidedDecodingParams
from tensorrt_llm.scheduling_params import SchedulingParams
from dynamo._core import Context
from dynamo._core import Client, Context
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.logits_processing.examples import HelloWorldLogitsProcessor
from dynamo.nixl_connect import Connector
......@@ -65,9 +65,9 @@ class RequestHandlerConfig:
engine: TensorRTLLMEngine
default_sampling_params: SamplingParams
publisher: Publisher
publisher: Optional[Publisher]
disaggregation_mode: DisaggregationMode
encode_client: Optional[object] = None
encode_client: Optional[Client] = None
multimodal_processor: Optional[
MultimodalRequestProcessor
] = None # for multimodal support
......@@ -558,11 +558,11 @@ class HandlerBase(BaseGenerativeHandler):
# PREFILL/ENCODE/AGGREGATED: Process multimodal content if available
if self.multimodal_processor:
processed_input = await self.multimodal_processor.process_openai_request(
mm_result = await self.multimodal_processor.process_openai_request(
request, embeddings, ep_disaggregated_params
)
if processed_input:
return processed_input
if mm_result:
return mm_result
# If multimodal processing returned None but request has multimodal data,
# this is an error (not a text-only request). Raise instead of falling back.
......
......@@ -111,6 +111,8 @@ class PrefillHandler(HandlerBase):
Encoder's embeddings tensor to be used by the prefill worker
"""
# Get response with shape info and readable metadata
if self.encode_client is None:
raise RuntimeError("Encode client is not configured.")
encode_response = None
async for res in await self.encode_client.round_robin(request):
encode_response = res.data()
......@@ -119,6 +121,8 @@ class PrefillHandler(HandlerBase):
if not encode_response:
raise RuntimeError("Did not receive a response from the encode worker.")
if self.connector is None:
raise RuntimeError("Connector is not configured.")
# Use utility function to handle NIXL reading and reconstruction
return await EncodeHelper.read_embeddings_from_encode_response(
encode_response, self.connector
......
......@@ -61,6 +61,9 @@ async def init_video_diffusion_worker(
else []
)
if not config.endpoint:
raise ValueError("endpoint must be configured for video diffusion worker")
# Build DiffusionConfig from the main Config
diffusion_config = DiffusionConfig(
namespace=config.namespace,
......
......@@ -13,7 +13,7 @@ import time
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Final
from typing import Any, AsyncIterator, Dict, Final, Generic, TypeVar
import torch
from vllm.config import VllmConfig
......@@ -23,6 +23,7 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.engine.exceptions import EngineDeadError
from dynamo._core import Context
from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.utils.input_params import InputParamManager
......@@ -325,7 +326,11 @@ def get_dp_range_for_worker(vllm_config: VllmConfig) -> tuple[int, int]:
)
class BaseWorkerHandler(ABC):
RequestT = TypeVar("RequestT")
ResponseT = TypeVar("ResponseT")
class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
"""
Request handler for the generate and clear_kv_blocks endpoints.
"""
......@@ -459,7 +464,7 @@ class BaseWorkerHandler(ABC):
return {"status": "error", "message": str(e)}
@abstractmethod
async def generate(self, request, context) -> AsyncGenerator[dict, None]:
def generate(self, request: RequestT, context: Context) -> AsyncIterator[ResponseT]:
raise NotImplementedError
async def _monitor_abort(self, context, request_id, is_prefill):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment