Unverified commit 2712426f, authored by jh-nv and committed by GitHub
Browse files

feat: enable mypy in pre-merge (#6732)

parent e5e118a1
......@@ -28,10 +28,6 @@ inputs:
description: 'Start MinIO service for LoRA tests (true/false)'
required: false
default: 'true'
enable_mypy:
description: 'Enable mypy type checking during test run (true/false)'
required: false
default: 'false'
hf_token:
required: false
parallel_mode:
......@@ -125,12 +121,7 @@ runs:
PYTEST_CMD="pytest -v --collect-only -m \"${{ inputs.pytest_marks }}\""
else
echo "🚀 Running pytest in normal mode"
MYPY_FLAG=""
if [[ "${{ inputs.enable_mypy }}" == "true" ]]; then
echo "🔍 Mypy type checking enabled"
MYPY_FLAG="--mypy"
fi
PYTEST_CMD="pytest --continue-on-collection-errors -v --tb=short --basetemp=/tmp/pytest_temp -o cache_dir=/tmp/.pytest_cache --junitxml=/workspace/test-results/${{ env.PYTEST_XML_FILE }} --durations=20 ${MYPY_FLAG} -m \"${{ inputs.pytest_marks }}\""
PYTEST_CMD="pytest --continue-on-collection-errors -v --tb=short --basetemp=/tmp/pytest_temp -o cache_dir=/tmp/.pytest_cache --junitxml=/workspace/test-results/${{ env.PYTEST_XML_FILE }} --durations=20 -m \"${{ inputs.pytest_marks }}\""
# Detect GPU availability and conditionally add GPU flags
GPU_FLAGS=""
......@@ -205,12 +196,6 @@ runs:
PYTEST_CMD="pytest -v --collect-only -m \"${{ inputs.pytest_marks }}\""
else
echo "🚀 Running pytest in normal mode"
MYPY_FLAG=""
if [[ "${{ inputs.enable_mypy }}" == "true" ]]; then
echo "🔍 Mypy type checking enabled"
MYPY_FLAG="--mypy"
fi
# Detect GPU availability and conditionally add GPU flags
GPU_FLAGS=""
# We check 'docker info' for the 'nvidia' runtime, which indicates the Daemon can spawn GPU containers.
......@@ -239,7 +224,7 @@ runs:
# Construct final command with xdist parallelization (-n) and other options
# --dist=loadscope groups tests by module/class to prevent race conditions in stateful tests
PYTEST_CMD="pytest ${PARALLEL_OPTS} --dist=loadscope --continue-on-collection-errors -v --tb=short --basetemp=/tmp/pytest_temp -o cache_dir=/tmp/.pytest_cache --junitxml=/workspace/test-results/${{ env.PYTEST_XML_FILE }} --durations=10 ${MYPY_FLAG} -m \"${{ inputs.pytest_marks }}\""
PYTEST_CMD="pytest ${PARALLEL_OPTS} --dist=loadscope --continue-on-collection-errors -v --tb=short --basetemp=/tmp/pytest_temp -o cache_dir=/tmp/.pytest_cache --junitxml=/workspace/test-results/${{ env.PYTEST_XML_FILE }} --durations=10 -m \"${{ inputs.pytest_marks }}\""
fi
# Get absolute path for test-results directory and ensure it has proper permissions
......
......@@ -274,7 +274,6 @@ jobs:
framework: ${{ inputs.framework }}
test_type: "pre_merge_cpu"
platform_arch: ${{ inputs.platform }}
enable_mypy: 'true'
hf_token: ${{ secrets.HF_TOKEN }}
parallel_mode: 'auto'
dind_as_sidecar: 'true'
......@@ -291,7 +290,6 @@ jobs:
framework: ${{ inputs.framework }}
test_type: "pre_merge_gpu"
platform_arch: ${{ inputs.platform }}
enable_mypy: 'false' # already covered by CPU tests
hf_token: ${{ secrets.HF_TOKEN }}
parallel_mode: 'none'
dind_as_sidecar: 'true'
......@@ -355,7 +353,6 @@ jobs:
framework: ${{ inputs.framework }}
test_type: "pre_merge_gpu"
platform_arch: ${{ inputs.platform }}
enable_mypy: 'false' # already covered by CPU tests
hf_token: ${{ secrets.HF_TOKEN }}
parallel_mode: 'none'
dind_as_sidecar: 'true'
......
......@@ -47,7 +47,7 @@ jobs:
dynamo-status-check:
runs-on: ubuntu-latest
needs: [changed-files, build, rust-checks, test-parallel, test-sequential]
needs: [changed-files, build, rust-checks, mypy, test-parallel, test-sequential]
if: always()
steps:
- name: "Check all dependent jobs"
......@@ -186,10 +186,37 @@ jobs:
cargo test --locked -p kvbm-physical --features testing-kvbm -- --nocapture --test-threads=4 && \
/workspace/container/use-sccache.sh show-stats "Rust Checks"'
test-parallel:
mypy:
needs: [changed-files, build]
if: needs.changed-files.outputs.core == 'true' || needs.changed-files.outputs.planner == 'true' || needs.changed-files.outputs.frontend == 'true'
runs-on: prod-builder-amd-v1
name: Mypy
timeout-minutes: 15
env:
IMAGE_TAG: ${{ secrets.AWS_ACCOUNT_ID }}.dkr.ecr.${{ secrets.AWS_DEFAULT_REGION }}.amazonaws.com/ai-dynamo/dynamo:${{ needs.build.outputs.test_tag_suffix }}
steps:
- name: Checkout repository
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
- name: Docker Login
uses: ./.github/actions/docker-login
with:
aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }}
aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }}
- name: Pull test image
run: |
source ./.github/scripts/retry_docker.sh
retry_pull ${{ env.IMAGE_TAG }}
- name: Run mypy
run: |
docker run --rm -w /workspace \
--name mypy_${{ github.run_id }}_${{ github.run_attempt }} \
${{ env.IMAGE_TAG }} \
bash -c 'MYPYPATH=components/src:lib/bindings/python/src mypy --explicit-package-bases components/src/dynamo && MYPYPATH=lib/bindings/python/src mypy -p dynamo'
test-parallel:
needs: [changed-files, build, mypy]
if: needs.changed-files.outputs.core == 'true' || needs.changed-files.outputs.planner == 'true' || needs.changed-files.outputs.frontend == 'true'
runs-on: prod-builder-amd-v1
name: Pytest (parallel)
timeout-minutes: 30
env:
......@@ -214,13 +241,12 @@ jobs:
framework: dynamo
test_type: "pre_merge_parallel"
platform_arch: amd64
enable_mypy: 'true'
hf_token: ${{ secrets.HF_TOKEN }}
parallel_mode: '4'
dind_as_sidecar: 'false'
test-sequential:
needs: [changed-files, build]
needs: [changed-files, build, mypy]
if: needs.changed-files.outputs.core == 'true' || needs.changed-files.outputs.planner == 'true' || needs.changed-files.outputs.frontend == 'true'
runs-on: prod-builder-amd-v1
name: Pytest (sequential)
......@@ -247,7 +273,6 @@ jobs:
framework: dynamo
test_type: "pre_merge_sequential"
platform_arch: amd64
enable_mypy: 'false'
hf_token: ${{ secrets.HF_TOKEN }}
parallel_mode: 'none'
dind_as_sidecar: 'false'
......@@ -66,15 +66,6 @@ repos:
- id: trailing-whitespace
exclude: lib/llm/tests/data/deepseek-v3.2/.*\.txt$
# NOTE: removing from pre commit
# will move to gitlab ci to run in proper
# container
#- repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.13.0
# hooks:
# - id: mypy
# exclude: model.py # WAR errors about 'model.py' duplicate module name
# Fast linting
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.2
......
......@@ -83,7 +83,7 @@ class WelfordAccumulator:
class ScheduledRequestMetrics(
msgspec.Struct,
frozen=True,
frozen=True, # type: ignore[call-arg]
gc=False,
):
"""Metrics for requests scheduled in this iteration"""
......@@ -121,7 +121,7 @@ class ScheduledRequestMetrics(
class QueuedRequestMetrics(
msgspec.Struct,
frozen=True,
frozen=True, # type: ignore[call-arg]
gc=False,
):
"""Metrics for requests waiting in the queue (not scheduled this iteration).
......@@ -152,7 +152,7 @@ class QueuedRequestMetrics(
class ForwardPassMetrics(
msgspec.Struct,
frozen=True,
frozen=True, # type: ignore[call-arg]
gc=False,
):
"""Per-iteration metrics emitted by InstrumentedScheduler.
......
......@@ -3,9 +3,13 @@
"""Multimodal utilities for Dynamo components."""
from collections.abc import Callable
from dynamo.common.constants import EmbeddingTransferMode
from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache
from dynamo.common.multimodal.embedding_transfer import (
AbstractEmbeddingReceiver,
AbstractEmbeddingSender,
LocalEmbeddingReceiver,
LocalEmbeddingSender,
NixlReadEmbeddingReceiver,
......@@ -16,13 +20,17 @@ from dynamo.common.multimodal.embedding_transfer import (
)
from dynamo.common.multimodal.image_loader import ImageLoader
EMBEDDING_SENDER_FACTORIES = {
EMBEDDING_SENDER_FACTORIES: dict[
EmbeddingTransferMode, Callable[[], AbstractEmbeddingSender]
] = {
EmbeddingTransferMode.LOCAL: LocalEmbeddingSender,
EmbeddingTransferMode.NIXL_WRITE: NixlWriteEmbeddingSender,
EmbeddingTransferMode.NIXL_READ: NixlReadEmbeddingSender,
}
EMBEDDING_RECEIVER_FACTORIES = {
EMBEDDING_RECEIVER_FACTORIES: dict[
EmbeddingTransferMode, Callable[[], AbstractEmbeddingReceiver]
] = {
EmbeddingTransferMode.LOCAL: LocalEmbeddingReceiver,
EmbeddingTransferMode.NIXL_WRITE: NixlWriteEmbeddingReceiver,
# [gluo FIXME] can't use pre-registered tensor as NIXL requires descriptors
......
......@@ -19,9 +19,8 @@ import asyncio
import logging
from typing import Awaitable, Callable, Dict, Optional
import torch
from dynamo.common.memory.multimodal_embedding_cache_manager import (
CachedEmbedding,
MultimodalEmbeddingCacheManager,
)
......@@ -63,9 +62,9 @@ class AsyncEncoderCache:
cache: Underlying MultimodalEmbeddingCacheManager for storage.
"""
self._cache = cache
self._in_flight: Dict[str, asyncio.Future[torch.Tensor]] = {}
self._in_flight: Dict[str, asyncio.Future[CachedEmbedding]] = {}
def get(self, key: str) -> Optional[torch.Tensor]:
def get(self, key: str) -> Optional[CachedEmbedding]:
"""
Synchronous get from underlying cache.
......@@ -73,15 +72,15 @@ class AsyncEncoderCache:
key: Cache key.
Returns:
Cached tensor or None if not found.
Cached embedding or None if not found.
"""
return self._cache.get(key)
async def get_or_compute(
self,
key: str,
compute_fn: Callable[[], Awaitable[torch.Tensor]],
) -> torch.Tensor:
compute_fn: Callable[[], Awaitable[CachedEmbedding]],
) -> CachedEmbedding:
"""
Get from cache or compute with request coalescing.
......@@ -91,10 +90,10 @@ class AsyncEncoderCache:
Args:
key: Cache key (typically content hash).
compute_fn: Async function to compute the tensor if not cached.
compute_fn: Async function to compute the embedding if not cached.
Returns:
The cached or computed tensor.
The cached or computed embedding.
Raises:
Exception: Re-raises any exception from compute_fn.
......@@ -110,14 +109,14 @@ class AsyncEncoderCache:
return await self._in_flight[key]
# Compute with coalescing
future: asyncio.Future[torch.Tensor] = asyncio.Future()
future: asyncio.Future[CachedEmbedding] = asyncio.Future()
future.add_done_callback(_suppress_unhandled_future_exception)
self._in_flight[key] = future
try:
tensor = await compute_fn()
self._cache.set(key, tensor)
future.set_result(tensor)
return tensor
embedding = await compute_fn()
self._cache.set(key, embedding)
future.set_result(embedding)
return embedding
except Exception as e:
future.set_exception(e)
raise
......
......@@ -191,6 +191,8 @@ class ImageLoader:
elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
if self._enable_frontend_decoding:
metadata = item[DECODED_VARIANT_KEY]
if self._nixl_connector is None:
raise RuntimeError("NIXL connector is not initialized")
image_futures.append(
read_decoded_media_via_nixl(self._nixl_connector, metadata)
)
......
......@@ -57,6 +57,7 @@ class TestMultimodalEmbeddingCacheManagerBasicOperations:
cache.set("key1", CachedEmbedding(tensor2))
retrieved = cache.get("key1")
assert retrieved is not None
assert torch.equal(retrieved.tensor, tensor2)
assert cache.stats["entries"] == 1
......
......@@ -9,6 +9,7 @@ import pytest
import torch
from dynamo.common.memory.multimodal_embedding_cache_manager import (
CachedEmbedding,
MultimodalEmbeddingCacheManager,
)
from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache
......@@ -30,43 +31,45 @@ class TestAsyncEncoderCacheBasicOperations:
def test_sync_get_returns_cached_tensor(self, cache):
"""Test sync get returns tensor after it's cached."""
tensor = torch.randn(10, 10)
cache._cache.set("key1", tensor)
cache._cache.set("key1", CachedEmbedding(tensor))
result = cache.get("key1")
assert torch.equal(result, tensor)
assert result is not None
assert torch.equal(result.tensor, tensor)
@pytest.mark.asyncio
async def test_get_or_compute_caches_result(self, cache):
"""Test get_or_compute caches the computed result."""
tensor = torch.randn(10, 10)
embedding = CachedEmbedding(tensor)
async def compute():
return tensor
return embedding
result = await cache.get_or_compute("key1", compute)
assert torch.equal(result, tensor)
assert torch.equal(result.tensor, tensor)
# Should be in cache now
cached = cache.get("key1")
assert cached is not None
assert torch.equal(cached, tensor)
assert torch.equal(cached.tensor, tensor)
@pytest.mark.asyncio
async def test_get_or_compute_returns_cached(self, cache):
"""Test get_or_compute returns cached value without computing."""
tensor = torch.randn(10, 10)
cache._cache.set("key1", tensor)
cache._cache.set("key1", CachedEmbedding(tensor))
compute_called = False
async def compute():
nonlocal compute_called
compute_called = True
return torch.randn(10, 10)
return CachedEmbedding(torch.randn(10, 10))
result = await cache.get_or_compute("key1", compute)
assert torch.equal(result, tensor)
assert torch.equal(result.tensor, tensor)
assert not compute_called
......@@ -84,6 +87,7 @@ class TestAsyncEncoderCacheRequestCoalescing:
"""Test that concurrent requests for same key only compute once."""
compute_count = 0
tensor = torch.randn(10, 10)
embedding = CachedEmbedding(tensor)
compute_started = asyncio.Event()
compute_proceed = asyncio.Event()
......@@ -92,7 +96,7 @@ class TestAsyncEncoderCacheRequestCoalescing:
compute_count += 1
compute_started.set() # Signal that compute has started
await compute_proceed.wait() # Wait for permission to proceed
return tensor
return embedding
# Start concurrent requests as tasks
task1 = asyncio.create_task(cache.get_or_compute("key1", compute))
......@@ -109,7 +113,7 @@ class TestAsyncEncoderCacheRequestCoalescing:
# All should get the same tensor
for result in results:
assert torch.equal(result, tensor)
assert torch.equal(result.tensor, tensor)
# But compute should only be called once
assert compute_count == 1
......@@ -122,7 +126,7 @@ class TestAsyncEncoderCacheRequestCoalescing:
async def compute():
nonlocal compute_count
compute_count += 1
return torch.randn(10, 10)
return CachedEmbedding(torch.randn(10, 10))
await asyncio.gather(
cache.get_or_compute("key1", compute),
......@@ -197,12 +201,13 @@ class TestAsyncEncoderCacheExceptionHandling:
# Should be able to retry
tensor = torch.randn(10, 10)
embedding = CachedEmbedding(tensor)
async def working_compute():
return tensor
return embedding
result = await cache.get_or_compute("key1", working_compute)
assert torch.equal(result, tensor)
assert torch.equal(result.tensor, tensor)
class TestAsyncEncoderCacheStats:
......@@ -226,7 +231,7 @@ class TestAsyncEncoderCacheStats:
tensor = torch.randn(10, 10)
async def compute():
return tensor
return CachedEmbedding(tensor)
await cache.get_or_compute("key1", compute)
......
......@@ -4,7 +4,7 @@
import logging
import time
from contextlib import contextmanager
from typing import Callable
from typing import Callable, Optional
logger = logging.getLogger(__name__)
......@@ -32,8 +32,8 @@ class Timer:
def __init__(
self,
interval_func: Callable[[float], None] = None,
stop_func: Callable[[float], None] = None,
interval_func: Optional[Callable[[float], None]] = None,
stop_func: Optional[Callable[[float], None]] = None,
):
"""Initialize the Timer and start timing immediately.
......
......@@ -292,7 +292,7 @@ class SglangStreamingPostProcessor:
missing_args = any(
idx not in self._tool_call_args for idx in self._tool_call_names
)
if missing_args:
if missing_args and self.tool_call_parser is not None:
buffer = getattr(self.tool_call_parser.detector, "_buffer", "")
if buffer:
_, final_calls = self.tool_call_parser.parse_non_stream(buffer)
......
......@@ -17,6 +17,7 @@ from typing import Any
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
from dynamo._core import Client
from dynamo._internal import ModelDeploymentCard
from dynamo.frontend.frontend_args import FrontendConfig
from dynamo.llm import (
......@@ -233,7 +234,10 @@ class SglangProcessor:
) -> AsyncGenerator[dict[str, Any], None]:
"""Main entry point: preprocess, route, post-process a chat request."""
if self.debug_perf:
from .perf_instrumentation import enter_generator, exit_generator
from .perf_instrumentation import ( # type: ignore[import-not-found, import-untyped]
enter_generator,
exit_generator,
)
active = enter_generator()
t_start = time.monotonic()
......@@ -320,6 +324,8 @@ class SglangProcessor:
request_id = random_uuid()
# --- Phase 1: Preprocess (semaphore held) ---
assert self._worker_semaphore is not None
assert self.preprocess_pool is not None
try:
async with self._worker_semaphore:
future = self.preprocess_pool.submit(
......@@ -543,7 +549,7 @@ class SglangEngineFactory:
generate_endpoint = self.runtime.endpoint(
f"{namespace_name}.{component_name}.{endpoint_name}"
)
router: Client | KvRouter
if self.router_config.router_mode == RouterMode.KV:
router = KvRouter(
endpoint=generate_endpoint,
......
......@@ -58,7 +58,7 @@ class ScaleRequestHandler:
self.k8s_namespace = k8s_namespace
self.no_operation = no_operation
self.max_total_gpus = max_total_gpus
self.connectors = {} # Cache of KubernetesConnector per DGD
self.connectors: dict[str, KubernetesConnector] = {} # Cache per DGD
# Serializes budget-check + scale-execution so concurrent requests from
# different pools cannot both pass against the same pre-scale state.
self._scale_lock = asyncio.Lock()
......
......@@ -52,6 +52,9 @@ async def worker(runtime: DistributedRuntime):
"""Main worker function for the Global Router service."""
config = parse_args()
# validate() ensures these are non-None; assert to narrow types for mypy
assert config.config_path is not None
assert config.model_name is not None
logger.info("Starting Global Router Service")
logger.info(f"Config: {config.config_path}")
logger.info(f"Model name: {config.model_name}")
......
......@@ -16,6 +16,7 @@
import argparse
import asyncio
import logging
from typing import Union
from pydantic import BaseModel
......@@ -39,6 +40,7 @@ class RequestType(BaseModel):
async def start_planner(runtime: DistributedRuntime, config: PlannerConfig):
mode = config.mode
planner: Union[DisaggPlanner, PrefillPlanner, DecodePlanner, AggPlanner]
if mode == "disagg":
planner = DisaggPlanner(runtime, config)
elif mode == "prefill":
......@@ -63,7 +65,7 @@ async def init_planner(runtime: DistributedRuntime, config: PlannerConfig):
yield "mock endpoint"
generate_endpoint = runtime.endpoint(f"{config.namespace}.Planner.generate")
await generate_endpoint.serve_endpoint(generate) # type: ignore[arg-type]
await generate_endpoint.serve_endpoint(generate)
def _parse_config() -> PlannerConfig:
......
......@@ -73,7 +73,7 @@ class SLAPlannerDefaults(BasePlannerDefaults):
no_correction = True
mode: Literal["disagg", "prefill", "decode", "agg"] = "disagg"
throughput_metrics_source = "frontend" # "frontend" | "router"
throughput_metrics_source: Literal["frontend", "router"] = "frontend"
# Scaling mode flags
enable_throughput_scaling = True
......@@ -90,7 +90,18 @@ class SLAPlannerDefaults(BasePlannerDefaults):
load_min_observations = 5 # cold start threshold
class VllmComponentName:
class ComponentName:
"""Base class for backend component name configurations."""
prefill_worker_k8s_name: str = ""
prefill_worker_component_name: str = ""
prefill_worker_endpoint: str = ""
decode_worker_k8s_name: str = ""
decode_worker_component_name: str = ""
decode_worker_endpoint: str = ""
class VllmComponentName(ComponentName):
prefill_worker_k8s_name = "VllmPrefillWorker"
prefill_worker_component_name = "prefill"
prefill_worker_endpoint = "generate"
......@@ -99,7 +110,7 @@ class VllmComponentName:
decode_worker_endpoint = "generate"
class SGLangComponentName:
class SGLangComponentName(ComponentName):
prefill_worker_k8s_name = (
"prefill" # use short name to stay within k8s limits with grove
)
......@@ -112,7 +123,7 @@ class SGLangComponentName:
decode_worker_endpoint = "generate"
class TrtllmComponentName:
class TrtllmComponentName(ComponentName):
# Unified frontend architecture (consistent with vLLM/SGLang):
# - Prefill workers use "prefill" component
# - Decode workers use "tensorrt_llm" component
......@@ -124,7 +135,7 @@ class TrtllmComponentName:
decode_worker_endpoint = "generate"
class MockerComponentName:
class MockerComponentName(ComponentName):
# Mocker backend for testing/simulation purposes
prefill_worker_k8s_name = "prefill"
prefill_worker_component_name = "prefill"
......@@ -134,7 +145,7 @@ class MockerComponentName:
decode_worker_endpoint = "generate"
WORKER_COMPONENT_NAMES = {
WORKER_COMPONENT_NAMES: dict[str, type[ComponentName]] = {
"vllm": VllmComponentName,
"sglang": SGLangComponentName,
"trtllm": TrtllmComponentName,
......
......@@ -222,6 +222,9 @@ class KubernetesConnector(PlannerConnector):
else:
raise e
if not model_name:
raise ModelNameNotFoundError()
# If user provided a model name and it doesn't match the model name from the deployment, raise an error
if self.user_provided_model_name:
if model_name != self.user_provided_model_name:
......@@ -229,9 +232,6 @@ class KubernetesConnector(PlannerConnector):
model_name, self.user_provided_model_name
)
if not model_name:
raise ModelNameNotFoundError()
return model_name
def get_gpu_counts(
......
......@@ -111,13 +111,12 @@ class AggPlanner:
logger.info(f"Detected model name from deployment: {model_name}")
self.planner.model_name = model_name.lower()
else:
model_name = getattr(self.config, "model_name", None)
if not model_name:
if not self.config.model_name:
raise ValueError(
"Model name is required in no-operation mode. "
"Please set model_name in the config."
)
self.planner.model_name = model_name.lower()
self.planner.model_name = self.config.model_name.lower()
loops = [
self._load_loop(),
......
......@@ -90,12 +90,17 @@ class DecodePlanner(BasePlanner):
"No decode workers found for correction factor, skipping correction update"
)
return True
assert self.last_metrics.num_req is not None
assert self.last_metrics.request_duration is not None
assert self.last_metrics.isl is not None
assert self.last_metrics.osl is not None
assert self.last_metrics.itl is not None
expect_itl = self.decode_interpolator.interpolate_itl(
concurrency=self.last_metrics.num_req # type: ignore
concurrency=self.last_metrics.num_req
/ self.shared_state.num_d_workers
* self.last_metrics.request_duration # type: ignore
* self.last_metrics.request_duration
/ self.config.throughput_adjustment_interval,
context_length=self.last_metrics.isl + self.last_metrics.osl / 2, # type: ignore
context_length=self.last_metrics.isl + self.last_metrics.osl / 2,
)
self.d_correction_factor = self.last_metrics.itl / expect_itl
logger.info(f"Correction factor (decode ITL): {self.d_correction_factor:.3f}")
......@@ -126,6 +131,7 @@ class DecodePlanner(BasePlanner):
"(no throughput satisfies ITL target), falling back to min_endpoint"
)
return self.config.min_endpoint
assert self.config.decode_engine_num_gpu is not None
pred_decode_throughput = (
next_num_req * next_osl / self.config.throughput_adjustment_interval
)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment