Unverified Commit 2712426f authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

feat: enable mypy in pre-merge (#6732)

parent e5e118a1
...@@ -28,10 +28,6 @@ inputs: ...@@ -28,10 +28,6 @@ inputs:
description: 'Start MinIO service for LoRA tests (true/false)' description: 'Start MinIO service for LoRA tests (true/false)'
required: false required: false
default: 'true' default: 'true'
enable_mypy:
description: 'Enable mypy type checking during test run (true/false)'
required: false
default: 'false'
hf_token: hf_token:
required: false required: false
parallel_mode: parallel_mode:
...@@ -125,12 +121,7 @@ runs: ...@@ -125,12 +121,7 @@ runs:
PYTEST_CMD="pytest -v --collect-only -m \"${{ inputs.pytest_marks }}\"" PYTEST_CMD="pytest -v --collect-only -m \"${{ inputs.pytest_marks }}\""
else else
echo "🚀 Running pytest in normal mode" echo "🚀 Running pytest in normal mode"
MYPY_FLAG="" PYTEST_CMD="pytest --continue-on-collection-errors -v --tb=short --basetemp=/tmp/pytest_temp -o cache_dir=/tmp/.pytest_cache --junitxml=/workspace/test-results/${{ env.PYTEST_XML_FILE }} --durations=20 -m \"${{ inputs.pytest_marks }}\""
if [[ "${{ inputs.enable_mypy }}" == "true" ]]; then
echo "🔍 Mypy type checking enabled"
MYPY_FLAG="--mypy"
fi
PYTEST_CMD="pytest --continue-on-collection-errors -v --tb=short --basetemp=/tmp/pytest_temp -o cache_dir=/tmp/.pytest_cache --junitxml=/workspace/test-results/${{ env.PYTEST_XML_FILE }} --durations=20 ${MYPY_FLAG} -m \"${{ inputs.pytest_marks }}\""
# Detect GPU availability and conditionally add GPU flags # Detect GPU availability and conditionally add GPU flags
GPU_FLAGS="" GPU_FLAGS=""
...@@ -205,12 +196,6 @@ runs: ...@@ -205,12 +196,6 @@ runs:
PYTEST_CMD="pytest -v --collect-only -m \"${{ inputs.pytest_marks }}\"" PYTEST_CMD="pytest -v --collect-only -m \"${{ inputs.pytest_marks }}\""
else else
echo "🚀 Running pytest in normal mode" echo "🚀 Running pytest in normal mode"
MYPY_FLAG=""
if [[ "${{ inputs.enable_mypy }}" == "true" ]]; then
echo "🔍 Mypy type checking enabled"
MYPY_FLAG="--mypy"
fi
# Detect GPU availability and conditionally add GPU flags # Detect GPU availability and conditionally add GPU flags
GPU_FLAGS="" GPU_FLAGS=""
# We check 'docker info' for the 'nvidia' runtime, which indicates the Daemon can spawn GPU containers. # We check 'docker info' for the 'nvidia' runtime, which indicates the Daemon can spawn GPU containers.
...@@ -239,7 +224,7 @@ runs: ...@@ -239,7 +224,7 @@ runs:
# Construct final command with xdist parallelization (-n) and other options # Construct final command with xdist parallelization (-n) and other options
# --dist=loadscope groups tests by module/class to prevent race conditions in stateful tests # --dist=loadscope groups tests by module/class to prevent race conditions in stateful tests
PYTEST_CMD="pytest ${PARALLEL_OPTS} --dist=loadscope --continue-on-collection-errors -v --tb=short --basetemp=/tmp/pytest_temp -o cache_dir=/tmp/.pytest_cache --junitxml=/workspace/test-results/${{ env.PYTEST_XML_FILE }} --durations=10 ${MYPY_FLAG} -m \"${{ inputs.pytest_marks }}\"" PYTEST_CMD="pytest ${PARALLEL_OPTS} --dist=loadscope --continue-on-collection-errors -v --tb=short --basetemp=/tmp/pytest_temp -o cache_dir=/tmp/.pytest_cache --junitxml=/workspace/test-results/${{ env.PYTEST_XML_FILE }} --durations=10 -m \"${{ inputs.pytest_marks }}\""
fi fi
# Get absolute path for test-results directory and ensure it has proper permissions # Get absolute path for test-results directory and ensure it has proper permissions
......
...@@ -274,7 +274,6 @@ jobs: ...@@ -274,7 +274,6 @@ jobs:
framework: ${{ inputs.framework }} framework: ${{ inputs.framework }}
test_type: "pre_merge_cpu" test_type: "pre_merge_cpu"
platform_arch: ${{ inputs.platform }} platform_arch: ${{ inputs.platform }}
enable_mypy: 'true'
hf_token: ${{ secrets.HF_TOKEN }} hf_token: ${{ secrets.HF_TOKEN }}
parallel_mode: 'auto' parallel_mode: 'auto'
dind_as_sidecar: 'true' dind_as_sidecar: 'true'
...@@ -291,7 +290,6 @@ jobs: ...@@ -291,7 +290,6 @@ jobs:
framework: ${{ inputs.framework }} framework: ${{ inputs.framework }}
test_type: "pre_merge_gpu" test_type: "pre_merge_gpu"
platform_arch: ${{ inputs.platform }} platform_arch: ${{ inputs.platform }}
enable_mypy: 'false' # already covered by CPU tests
hf_token: ${{ secrets.HF_TOKEN }} hf_token: ${{ secrets.HF_TOKEN }}
parallel_mode: 'none' parallel_mode: 'none'
dind_as_sidecar: 'true' dind_as_sidecar: 'true'
...@@ -355,7 +353,6 @@ jobs: ...@@ -355,7 +353,6 @@ jobs:
framework: ${{ inputs.framework }} framework: ${{ inputs.framework }}
test_type: "pre_merge_gpu" test_type: "pre_merge_gpu"
platform_arch: ${{ inputs.platform }} platform_arch: ${{ inputs.platform }}
enable_mypy: 'false' # already covered by CPU tests
hf_token: ${{ secrets.HF_TOKEN }} hf_token: ${{ secrets.HF_TOKEN }}
parallel_mode: 'none' parallel_mode: 'none'
dind_as_sidecar: 'true' dind_as_sidecar: 'true'
......
...@@ -47,7 +47,7 @@ jobs: ...@@ -47,7 +47,7 @@ jobs:
dynamo-status-check: dynamo-status-check:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [changed-files, build, rust-checks, test-parallel, test-sequential] needs: [changed-files, build, rust-checks, mypy, test-parallel, test-sequential]
if: always() if: always()
steps: steps:
- name: "Check all dependent jobs" - name: "Check all dependent jobs"
...@@ -186,10 +186,37 @@ jobs: ...@@ -186,10 +186,37 @@ jobs:
cargo test --locked -p kvbm-physical --features testing-kvbm -- --nocapture --test-threads=4 && \ cargo test --locked -p kvbm-physical --features testing-kvbm -- --nocapture --test-threads=4 && \
/workspace/container/use-sccache.sh show-stats "Rust Checks"' /workspace/container/use-sccache.sh show-stats "Rust Checks"'
test-parallel: mypy:
needs: [changed-files, build] needs: [changed-files, build]
if: needs.changed-files.outputs.core == 'true' || needs.changed-files.outputs.planner == 'true' || needs.changed-files.outputs.frontend == 'true' if: needs.changed-files.outputs.core == 'true' || needs.changed-files.outputs.planner == 'true' || needs.changed-files.outputs.frontend == 'true'
runs-on: prod-builder-amd-v1 runs-on: prod-builder-amd-v1
name: Mypy
timeout-minutes: 15
env:
IMAGE_TAG: ${{ secrets.AWS_ACCOUNT_ID }}.dkr.ecr.${{ secrets.AWS_DEFAULT_REGION }}.amazonaws.com/ai-dynamo/dynamo:${{ needs.build.outputs.test_tag_suffix }}
steps:
- name: Checkout repository
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0
- name: Docker Login
uses: ./.github/actions/docker-login
with:
aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }}
aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }}
- name: Pull test image
run: |
source ./.github/scripts/retry_docker.sh
retry_pull ${{ env.IMAGE_TAG }}
- name: Run mypy
run: |
docker run --rm -w /workspace \
--name mypy_${{ github.run_id }}_${{ github.run_attempt }} \
${{ env.IMAGE_TAG }} \
bash -c 'MYPYPATH=components/src:lib/bindings/python/src mypy --explicit-package-bases components/src/dynamo && MYPYPATH=lib/bindings/python/src mypy -p dynamo'
test-parallel:
needs: [changed-files, build, mypy]
if: needs.changed-files.outputs.core == 'true' || needs.changed-files.outputs.planner == 'true' || needs.changed-files.outputs.frontend == 'true'
runs-on: prod-builder-amd-v1
name: Pytest (parallel) name: Pytest (parallel)
timeout-minutes: 30 timeout-minutes: 30
env: env:
...@@ -214,13 +241,12 @@ jobs: ...@@ -214,13 +241,12 @@ jobs:
framework: dynamo framework: dynamo
test_type: "pre_merge_parallel" test_type: "pre_merge_parallel"
platform_arch: amd64 platform_arch: amd64
enable_mypy: 'true'
hf_token: ${{ secrets.HF_TOKEN }} hf_token: ${{ secrets.HF_TOKEN }}
parallel_mode: '4' parallel_mode: '4'
dind_as_sidecar: 'false' dind_as_sidecar: 'false'
test-sequential: test-sequential:
needs: [changed-files, build] needs: [changed-files, build, mypy]
if: needs.changed-files.outputs.core == 'true' || needs.changed-files.outputs.planner == 'true' || needs.changed-files.outputs.frontend == 'true' if: needs.changed-files.outputs.core == 'true' || needs.changed-files.outputs.planner == 'true' || needs.changed-files.outputs.frontend == 'true'
runs-on: prod-builder-amd-v1 runs-on: prod-builder-amd-v1
name: Pytest (sequential) name: Pytest (sequential)
...@@ -247,7 +273,6 @@ jobs: ...@@ -247,7 +273,6 @@ jobs:
framework: dynamo framework: dynamo
test_type: "pre_merge_sequential" test_type: "pre_merge_sequential"
platform_arch: amd64 platform_arch: amd64
enable_mypy: 'false'
hf_token: ${{ secrets.HF_TOKEN }} hf_token: ${{ secrets.HF_TOKEN }}
parallel_mode: 'none' parallel_mode: 'none'
dind_as_sidecar: 'false' dind_as_sidecar: 'false'
...@@ -66,15 +66,6 @@ repos: ...@@ -66,15 +66,6 @@ repos:
- id: trailing-whitespace - id: trailing-whitespace
exclude: lib/llm/tests/data/deepseek-v3.2/.*\.txt$ exclude: lib/llm/tests/data/deepseek-v3.2/.*\.txt$
# NOTE: removing from pre commit
# will move to gitlab ci to run in proper
# container
#- repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.13.0
# hooks:
# - id: mypy
# exclude: model.py # WAR errors about 'model.py' duplicate module name
# Fast linting # Fast linting
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.2 rev: v0.5.2
......
...@@ -83,7 +83,7 @@ class WelfordAccumulator: ...@@ -83,7 +83,7 @@ class WelfordAccumulator:
class ScheduledRequestMetrics( class ScheduledRequestMetrics(
msgspec.Struct, msgspec.Struct,
frozen=True, frozen=True, # type: ignore[call-arg]
gc=False, gc=False,
): ):
"""Metrics for requests scheduled in this iteration""" """Metrics for requests scheduled in this iteration"""
...@@ -121,7 +121,7 @@ class ScheduledRequestMetrics( ...@@ -121,7 +121,7 @@ class ScheduledRequestMetrics(
class QueuedRequestMetrics( class QueuedRequestMetrics(
msgspec.Struct, msgspec.Struct,
frozen=True, frozen=True, # type: ignore[call-arg]
gc=False, gc=False,
): ):
"""Metrics for requests waiting in the queue (not scheduled this iteration). """Metrics for requests waiting in the queue (not scheduled this iteration).
...@@ -152,7 +152,7 @@ class QueuedRequestMetrics( ...@@ -152,7 +152,7 @@ class QueuedRequestMetrics(
class ForwardPassMetrics( class ForwardPassMetrics(
msgspec.Struct, msgspec.Struct,
frozen=True, frozen=True, # type: ignore[call-arg]
gc=False, gc=False,
): ):
"""Per-iteration metrics emitted by InstrumentedScheduler. """Per-iteration metrics emitted by InstrumentedScheduler.
......
...@@ -3,9 +3,13 @@ ...@@ -3,9 +3,13 @@
"""Multimodal utilities for Dynamo components.""" """Multimodal utilities for Dynamo components."""
from collections.abc import Callable
from dynamo.common.constants import EmbeddingTransferMode from dynamo.common.constants import EmbeddingTransferMode
from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache
from dynamo.common.multimodal.embedding_transfer import ( from dynamo.common.multimodal.embedding_transfer import (
AbstractEmbeddingReceiver,
AbstractEmbeddingSender,
LocalEmbeddingReceiver, LocalEmbeddingReceiver,
LocalEmbeddingSender, LocalEmbeddingSender,
NixlReadEmbeddingReceiver, NixlReadEmbeddingReceiver,
...@@ -16,13 +20,17 @@ from dynamo.common.multimodal.embedding_transfer import ( ...@@ -16,13 +20,17 @@ from dynamo.common.multimodal.embedding_transfer import (
) )
from dynamo.common.multimodal.image_loader import ImageLoader from dynamo.common.multimodal.image_loader import ImageLoader
EMBEDDING_SENDER_FACTORIES = { EMBEDDING_SENDER_FACTORIES: dict[
EmbeddingTransferMode, Callable[[], AbstractEmbeddingSender]
] = {
EmbeddingTransferMode.LOCAL: LocalEmbeddingSender, EmbeddingTransferMode.LOCAL: LocalEmbeddingSender,
EmbeddingTransferMode.NIXL_WRITE: NixlWriteEmbeddingSender, EmbeddingTransferMode.NIXL_WRITE: NixlWriteEmbeddingSender,
EmbeddingTransferMode.NIXL_READ: NixlReadEmbeddingSender, EmbeddingTransferMode.NIXL_READ: NixlReadEmbeddingSender,
} }
EMBEDDING_RECEIVER_FACTORIES = { EMBEDDING_RECEIVER_FACTORIES: dict[
EmbeddingTransferMode, Callable[[], AbstractEmbeddingReceiver]
] = {
EmbeddingTransferMode.LOCAL: LocalEmbeddingReceiver, EmbeddingTransferMode.LOCAL: LocalEmbeddingReceiver,
EmbeddingTransferMode.NIXL_WRITE: NixlWriteEmbeddingReceiver, EmbeddingTransferMode.NIXL_WRITE: NixlWriteEmbeddingReceiver,
# [gluo FIXME] can't use pre-registered tensor as NIXL requires descriptors # [gluo FIXME] can't use pre-registered tensor as NIXL requires descriptors
......
...@@ -19,9 +19,8 @@ import asyncio ...@@ -19,9 +19,8 @@ import asyncio
import logging import logging
from typing import Awaitable, Callable, Dict, Optional from typing import Awaitable, Callable, Dict, Optional
import torch
from dynamo.common.memory.multimodal_embedding_cache_manager import ( from dynamo.common.memory.multimodal_embedding_cache_manager import (
CachedEmbedding,
MultimodalEmbeddingCacheManager, MultimodalEmbeddingCacheManager,
) )
...@@ -63,9 +62,9 @@ class AsyncEncoderCache: ...@@ -63,9 +62,9 @@ class AsyncEncoderCache:
cache: Underlying MultimodalEmbeddingCacheManager for storage. cache: Underlying MultimodalEmbeddingCacheManager for storage.
""" """
self._cache = cache self._cache = cache
self._in_flight: Dict[str, asyncio.Future[torch.Tensor]] = {} self._in_flight: Dict[str, asyncio.Future[CachedEmbedding]] = {}
def get(self, key: str) -> Optional[torch.Tensor]: def get(self, key: str) -> Optional[CachedEmbedding]:
""" """
Synchronous get from underlying cache. Synchronous get from underlying cache.
...@@ -73,15 +72,15 @@ class AsyncEncoderCache: ...@@ -73,15 +72,15 @@ class AsyncEncoderCache:
key: Cache key. key: Cache key.
Returns: Returns:
Cached tensor or None if not found. Cached embedding or None if not found.
""" """
return self._cache.get(key) return self._cache.get(key)
async def get_or_compute( async def get_or_compute(
self, self,
key: str, key: str,
compute_fn: Callable[[], Awaitable[torch.Tensor]], compute_fn: Callable[[], Awaitable[CachedEmbedding]],
) -> torch.Tensor: ) -> CachedEmbedding:
""" """
Get from cache or compute with request coalescing. Get from cache or compute with request coalescing.
...@@ -91,10 +90,10 @@ class AsyncEncoderCache: ...@@ -91,10 +90,10 @@ class AsyncEncoderCache:
Args: Args:
key: Cache key (typically content hash). key: Cache key (typically content hash).
compute_fn: Async function to compute the tensor if not cached. compute_fn: Async function to compute the embedding if not cached.
Returns: Returns:
The cached or computed tensor. The cached or computed embedding.
Raises: Raises:
Exception: Re-raises any exception from compute_fn. Exception: Re-raises any exception from compute_fn.
...@@ -110,14 +109,14 @@ class AsyncEncoderCache: ...@@ -110,14 +109,14 @@ class AsyncEncoderCache:
return await self._in_flight[key] return await self._in_flight[key]
# Compute with coalescing # Compute with coalescing
future: asyncio.Future[torch.Tensor] = asyncio.Future() future: asyncio.Future[CachedEmbedding] = asyncio.Future()
future.add_done_callback(_suppress_unhandled_future_exception) future.add_done_callback(_suppress_unhandled_future_exception)
self._in_flight[key] = future self._in_flight[key] = future
try: try:
tensor = await compute_fn() embedding = await compute_fn()
self._cache.set(key, tensor) self._cache.set(key, embedding)
future.set_result(tensor) future.set_result(embedding)
return tensor return embedding
except Exception as e: except Exception as e:
future.set_exception(e) future.set_exception(e)
raise raise
......
...@@ -191,6 +191,8 @@ class ImageLoader: ...@@ -191,6 +191,8 @@ class ImageLoader:
elif isinstance(item, dict) and DECODED_VARIANT_KEY in item: elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
if self._enable_frontend_decoding: if self._enable_frontend_decoding:
metadata = item[DECODED_VARIANT_KEY] metadata = item[DECODED_VARIANT_KEY]
if self._nixl_connector is None:
raise RuntimeError("NIXL connector is not initialized")
image_futures.append( image_futures.append(
read_decoded_media_via_nixl(self._nixl_connector, metadata) read_decoded_media_via_nixl(self._nixl_connector, metadata)
) )
......
...@@ -57,6 +57,7 @@ class TestMultimodalEmbeddingCacheManagerBasicOperations: ...@@ -57,6 +57,7 @@ class TestMultimodalEmbeddingCacheManagerBasicOperations:
cache.set("key1", CachedEmbedding(tensor2)) cache.set("key1", CachedEmbedding(tensor2))
retrieved = cache.get("key1") retrieved = cache.get("key1")
assert retrieved is not None
assert torch.equal(retrieved.tensor, tensor2) assert torch.equal(retrieved.tensor, tensor2)
assert cache.stats["entries"] == 1 assert cache.stats["entries"] == 1
......
...@@ -9,6 +9,7 @@ import pytest ...@@ -9,6 +9,7 @@ import pytest
import torch import torch
from dynamo.common.memory.multimodal_embedding_cache_manager import ( from dynamo.common.memory.multimodal_embedding_cache_manager import (
CachedEmbedding,
MultimodalEmbeddingCacheManager, MultimodalEmbeddingCacheManager,
) )
from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache
...@@ -30,43 +31,45 @@ class TestAsyncEncoderCacheBasicOperations: ...@@ -30,43 +31,45 @@ class TestAsyncEncoderCacheBasicOperations:
def test_sync_get_returns_cached_tensor(self, cache): def test_sync_get_returns_cached_tensor(self, cache):
"""Test sync get returns tensor after it's cached.""" """Test sync get returns tensor after it's cached."""
tensor = torch.randn(10, 10) tensor = torch.randn(10, 10)
cache._cache.set("key1", tensor) cache._cache.set("key1", CachedEmbedding(tensor))
result = cache.get("key1") result = cache.get("key1")
assert torch.equal(result, tensor) assert result is not None
assert torch.equal(result.tensor, tensor)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_or_compute_caches_result(self, cache): async def test_get_or_compute_caches_result(self, cache):
"""Test get_or_compute caches the computed result.""" """Test get_or_compute caches the computed result."""
tensor = torch.randn(10, 10) tensor = torch.randn(10, 10)
embedding = CachedEmbedding(tensor)
async def compute(): async def compute():
return tensor return embedding
result = await cache.get_or_compute("key1", compute) result = await cache.get_or_compute("key1", compute)
assert torch.equal(result, tensor) assert torch.equal(result.tensor, tensor)
# Should be in cache now # Should be in cache now
cached = cache.get("key1") cached = cache.get("key1")
assert cached is not None assert cached is not None
assert torch.equal(cached, tensor) assert torch.equal(cached.tensor, tensor)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_or_compute_returns_cached(self, cache): async def test_get_or_compute_returns_cached(self, cache):
"""Test get_or_compute returns cached value without computing.""" """Test get_or_compute returns cached value without computing."""
tensor = torch.randn(10, 10) tensor = torch.randn(10, 10)
cache._cache.set("key1", tensor) cache._cache.set("key1", CachedEmbedding(tensor))
compute_called = False compute_called = False
async def compute(): async def compute():
nonlocal compute_called nonlocal compute_called
compute_called = True compute_called = True
return torch.randn(10, 10) return CachedEmbedding(torch.randn(10, 10))
result = await cache.get_or_compute("key1", compute) result = await cache.get_or_compute("key1", compute)
assert torch.equal(result, tensor) assert torch.equal(result.tensor, tensor)
assert not compute_called assert not compute_called
...@@ -84,6 +87,7 @@ class TestAsyncEncoderCacheRequestCoalescing: ...@@ -84,6 +87,7 @@ class TestAsyncEncoderCacheRequestCoalescing:
"""Test that concurrent requests for same key only compute once.""" """Test that concurrent requests for same key only compute once."""
compute_count = 0 compute_count = 0
tensor = torch.randn(10, 10) tensor = torch.randn(10, 10)
embedding = CachedEmbedding(tensor)
compute_started = asyncio.Event() compute_started = asyncio.Event()
compute_proceed = asyncio.Event() compute_proceed = asyncio.Event()
...@@ -92,7 +96,7 @@ class TestAsyncEncoderCacheRequestCoalescing: ...@@ -92,7 +96,7 @@ class TestAsyncEncoderCacheRequestCoalescing:
compute_count += 1 compute_count += 1
compute_started.set() # Signal that compute has started compute_started.set() # Signal that compute has started
await compute_proceed.wait() # Wait for permission to proceed await compute_proceed.wait() # Wait for permission to proceed
return tensor return embedding
# Start concurrent requests as tasks # Start concurrent requests as tasks
task1 = asyncio.create_task(cache.get_or_compute("key1", compute)) task1 = asyncio.create_task(cache.get_or_compute("key1", compute))
...@@ -109,7 +113,7 @@ class TestAsyncEncoderCacheRequestCoalescing: ...@@ -109,7 +113,7 @@ class TestAsyncEncoderCacheRequestCoalescing:
# All should get the same tensor # All should get the same tensor
for result in results: for result in results:
assert torch.equal(result, tensor) assert torch.equal(result.tensor, tensor)
# But compute should only be called once # But compute should only be called once
assert compute_count == 1 assert compute_count == 1
...@@ -122,7 +126,7 @@ class TestAsyncEncoderCacheRequestCoalescing: ...@@ -122,7 +126,7 @@ class TestAsyncEncoderCacheRequestCoalescing:
async def compute(): async def compute():
nonlocal compute_count nonlocal compute_count
compute_count += 1 compute_count += 1
return torch.randn(10, 10) return CachedEmbedding(torch.randn(10, 10))
await asyncio.gather( await asyncio.gather(
cache.get_or_compute("key1", compute), cache.get_or_compute("key1", compute),
...@@ -197,12 +201,13 @@ class TestAsyncEncoderCacheExceptionHandling: ...@@ -197,12 +201,13 @@ class TestAsyncEncoderCacheExceptionHandling:
# Should be able to retry # Should be able to retry
tensor = torch.randn(10, 10) tensor = torch.randn(10, 10)
embedding = CachedEmbedding(tensor)
async def working_compute(): async def working_compute():
return tensor return embedding
result = await cache.get_or_compute("key1", working_compute) result = await cache.get_or_compute("key1", working_compute)
assert torch.equal(result, tensor) assert torch.equal(result.tensor, tensor)
class TestAsyncEncoderCacheStats: class TestAsyncEncoderCacheStats:
...@@ -226,7 +231,7 @@ class TestAsyncEncoderCacheStats: ...@@ -226,7 +231,7 @@ class TestAsyncEncoderCacheStats:
tensor = torch.randn(10, 10) tensor = torch.randn(10, 10)
async def compute(): async def compute():
return tensor return CachedEmbedding(tensor)
await cache.get_or_compute("key1", compute) await cache.get_or_compute("key1", compute)
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import logging import logging
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from typing import Callable from typing import Callable, Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -32,8 +32,8 @@ class Timer: ...@@ -32,8 +32,8 @@ class Timer:
def __init__( def __init__(
self, self,
interval_func: Callable[[float], None] = None, interval_func: Optional[Callable[[float], None]] = None,
stop_func: Callable[[float], None] = None, stop_func: Optional[Callable[[float], None]] = None,
): ):
"""Initialize the Timer and start timing immediately. """Initialize the Timer and start timing immediately.
......
...@@ -292,7 +292,7 @@ class SglangStreamingPostProcessor: ...@@ -292,7 +292,7 @@ class SglangStreamingPostProcessor:
missing_args = any( missing_args = any(
idx not in self._tool_call_args for idx in self._tool_call_names idx not in self._tool_call_args for idx in self._tool_call_names
) )
if missing_args: if missing_args and self.tool_call_parser is not None:
buffer = getattr(self.tool_call_parser.detector, "_buffer", "") buffer = getattr(self.tool_call_parser.detector, "_buffer", "")
if buffer: if buffer:
_, final_calls = self.tool_call_parser.parse_non_stream(buffer) _, final_calls = self.tool_call_parser.parse_non_stream(buffer)
......
...@@ -17,6 +17,7 @@ from typing import Any ...@@ -17,6 +17,7 @@ from typing import Any
from sglang.srt.utils.hf_transformers_utils import get_tokenizer from sglang.srt.utils.hf_transformers_utils import get_tokenizer
from dynamo._core import Client
from dynamo._internal import ModelDeploymentCard from dynamo._internal import ModelDeploymentCard
from dynamo.frontend.frontend_args import FrontendConfig from dynamo.frontend.frontend_args import FrontendConfig
from dynamo.llm import ( from dynamo.llm import (
...@@ -233,7 +234,10 @@ class SglangProcessor: ...@@ -233,7 +234,10 @@ class SglangProcessor:
) -> AsyncGenerator[dict[str, Any], None]: ) -> AsyncGenerator[dict[str, Any], None]:
"""Main entry point: preprocess, route, post-process a chat request.""" """Main entry point: preprocess, route, post-process a chat request."""
if self.debug_perf: if self.debug_perf:
from .perf_instrumentation import enter_generator, exit_generator from .perf_instrumentation import ( # type: ignore[import-not-found, import-untyped]
enter_generator,
exit_generator,
)
active = enter_generator() active = enter_generator()
t_start = time.monotonic() t_start = time.monotonic()
...@@ -320,6 +324,8 @@ class SglangProcessor: ...@@ -320,6 +324,8 @@ class SglangProcessor:
request_id = random_uuid() request_id = random_uuid()
# --- Phase 1: Preprocess (semaphore held) --- # --- Phase 1: Preprocess (semaphore held) ---
assert self._worker_semaphore is not None
assert self.preprocess_pool is not None
try: try:
async with self._worker_semaphore: async with self._worker_semaphore:
future = self.preprocess_pool.submit( future = self.preprocess_pool.submit(
...@@ -543,7 +549,7 @@ class SglangEngineFactory: ...@@ -543,7 +549,7 @@ class SglangEngineFactory:
generate_endpoint = self.runtime.endpoint( generate_endpoint = self.runtime.endpoint(
f"{namespace_name}.{component_name}.{endpoint_name}" f"{namespace_name}.{component_name}.{endpoint_name}"
) )
router: Client | KvRouter
if self.router_config.router_mode == RouterMode.KV: if self.router_config.router_mode == RouterMode.KV:
router = KvRouter( router = KvRouter(
endpoint=generate_endpoint, endpoint=generate_endpoint,
......
...@@ -58,7 +58,7 @@ class ScaleRequestHandler: ...@@ -58,7 +58,7 @@ class ScaleRequestHandler:
self.k8s_namespace = k8s_namespace self.k8s_namespace = k8s_namespace
self.no_operation = no_operation self.no_operation = no_operation
self.max_total_gpus = max_total_gpus self.max_total_gpus = max_total_gpus
self.connectors = {} # Cache of KubernetesConnector per DGD self.connectors: dict[str, KubernetesConnector] = {} # Cache per DGD
# Serializes budget-check + scale-execution so concurrent requests from # Serializes budget-check + scale-execution so concurrent requests from
# different pools cannot both pass against the same pre-scale state. # different pools cannot both pass against the same pre-scale state.
self._scale_lock = asyncio.Lock() self._scale_lock = asyncio.Lock()
......
...@@ -52,6 +52,9 @@ async def worker(runtime: DistributedRuntime): ...@@ -52,6 +52,9 @@ async def worker(runtime: DistributedRuntime):
"""Main worker function for the Global Router service.""" """Main worker function for the Global Router service."""
config = parse_args() config = parse_args()
# validate() ensures these are non-None; assert to narrow types for mypy
assert config.config_path is not None
assert config.model_name is not None
logger.info("Starting Global Router Service") logger.info("Starting Global Router Service")
logger.info(f"Config: {config.config_path}") logger.info(f"Config: {config.config_path}")
logger.info(f"Model name: {config.model_name}") logger.info(f"Model name: {config.model_name}")
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import argparse import argparse
import asyncio import asyncio
import logging import logging
from typing import Union
from pydantic import BaseModel from pydantic import BaseModel
...@@ -39,6 +40,7 @@ class RequestType(BaseModel): ...@@ -39,6 +40,7 @@ class RequestType(BaseModel):
async def start_planner(runtime: DistributedRuntime, config: PlannerConfig): async def start_planner(runtime: DistributedRuntime, config: PlannerConfig):
mode = config.mode mode = config.mode
planner: Union[DisaggPlanner, PrefillPlanner, DecodePlanner, AggPlanner]
if mode == "disagg": if mode == "disagg":
planner = DisaggPlanner(runtime, config) planner = DisaggPlanner(runtime, config)
elif mode == "prefill": elif mode == "prefill":
...@@ -63,7 +65,7 @@ async def init_planner(runtime: DistributedRuntime, config: PlannerConfig): ...@@ -63,7 +65,7 @@ async def init_planner(runtime: DistributedRuntime, config: PlannerConfig):
yield "mock endpoint" yield "mock endpoint"
generate_endpoint = runtime.endpoint(f"{config.namespace}.Planner.generate") generate_endpoint = runtime.endpoint(f"{config.namespace}.Planner.generate")
await generate_endpoint.serve_endpoint(generate) # type: ignore[arg-type] await generate_endpoint.serve_endpoint(generate)
def _parse_config() -> PlannerConfig: def _parse_config() -> PlannerConfig:
......
...@@ -73,7 +73,7 @@ class SLAPlannerDefaults(BasePlannerDefaults): ...@@ -73,7 +73,7 @@ class SLAPlannerDefaults(BasePlannerDefaults):
no_correction = True no_correction = True
mode: Literal["disagg", "prefill", "decode", "agg"] = "disagg" mode: Literal["disagg", "prefill", "decode", "agg"] = "disagg"
throughput_metrics_source = "frontend" # "frontend" | "router" throughput_metrics_source: Literal["frontend", "router"] = "frontend"
# Scaling mode flags # Scaling mode flags
enable_throughput_scaling = True enable_throughput_scaling = True
...@@ -90,7 +90,18 @@ class SLAPlannerDefaults(BasePlannerDefaults): ...@@ -90,7 +90,18 @@ class SLAPlannerDefaults(BasePlannerDefaults):
load_min_observations = 5 # cold start threshold load_min_observations = 5 # cold start threshold
class VllmComponentName: class ComponentName:
"""Base class for backend component name configurations."""
prefill_worker_k8s_name: str = ""
prefill_worker_component_name: str = ""
prefill_worker_endpoint: str = ""
decode_worker_k8s_name: str = ""
decode_worker_component_name: str = ""
decode_worker_endpoint: str = ""
class VllmComponentName(ComponentName):
prefill_worker_k8s_name = "VllmPrefillWorker" prefill_worker_k8s_name = "VllmPrefillWorker"
prefill_worker_component_name = "prefill" prefill_worker_component_name = "prefill"
prefill_worker_endpoint = "generate" prefill_worker_endpoint = "generate"
...@@ -99,7 +110,7 @@ class VllmComponentName: ...@@ -99,7 +110,7 @@ class VllmComponentName:
decode_worker_endpoint = "generate" decode_worker_endpoint = "generate"
class SGLangComponentName: class SGLangComponentName(ComponentName):
prefill_worker_k8s_name = ( prefill_worker_k8s_name = (
"prefill" # use short name to stay within k8s limits with grove "prefill" # use short name to stay within k8s limits with grove
) )
...@@ -112,7 +123,7 @@ class SGLangComponentName: ...@@ -112,7 +123,7 @@ class SGLangComponentName:
decode_worker_endpoint = "generate" decode_worker_endpoint = "generate"
class TrtllmComponentName: class TrtllmComponentName(ComponentName):
# Unified frontend architecture (consistent with vLLM/SGLang): # Unified frontend architecture (consistent with vLLM/SGLang):
# - Prefill workers use "prefill" component # - Prefill workers use "prefill" component
# - Decode workers use "tensorrt_llm" component # - Decode workers use "tensorrt_llm" component
...@@ -124,7 +135,7 @@ class TrtllmComponentName: ...@@ -124,7 +135,7 @@ class TrtllmComponentName:
decode_worker_endpoint = "generate" decode_worker_endpoint = "generate"
class MockerComponentName: class MockerComponentName(ComponentName):
# Mocker backend for testing/simulation purposes # Mocker backend for testing/simulation purposes
prefill_worker_k8s_name = "prefill" prefill_worker_k8s_name = "prefill"
prefill_worker_component_name = "prefill" prefill_worker_component_name = "prefill"
...@@ -134,7 +145,7 @@ class MockerComponentName: ...@@ -134,7 +145,7 @@ class MockerComponentName:
decode_worker_endpoint = "generate" decode_worker_endpoint = "generate"
WORKER_COMPONENT_NAMES = { WORKER_COMPONENT_NAMES: dict[str, type[ComponentName]] = {
"vllm": VllmComponentName, "vllm": VllmComponentName,
"sglang": SGLangComponentName, "sglang": SGLangComponentName,
"trtllm": TrtllmComponentName, "trtllm": TrtllmComponentName,
......
...@@ -222,6 +222,9 @@ class KubernetesConnector(PlannerConnector): ...@@ -222,6 +222,9 @@ class KubernetesConnector(PlannerConnector):
else: else:
raise e raise e
if not model_name:
raise ModelNameNotFoundError()
# If user provided a model name and it doesn't match the model name from the deployment, raise an error # If user provided a model name and it doesn't match the model name from the deployment, raise an error
if self.user_provided_model_name: if self.user_provided_model_name:
if model_name != self.user_provided_model_name: if model_name != self.user_provided_model_name:
...@@ -229,9 +232,6 @@ class KubernetesConnector(PlannerConnector): ...@@ -229,9 +232,6 @@ class KubernetesConnector(PlannerConnector):
model_name, self.user_provided_model_name model_name, self.user_provided_model_name
) )
if not model_name:
raise ModelNameNotFoundError()
return model_name return model_name
def get_gpu_counts( def get_gpu_counts(
......
...@@ -111,13 +111,12 @@ class AggPlanner: ...@@ -111,13 +111,12 @@ class AggPlanner:
logger.info(f"Detected model name from deployment: {model_name}") logger.info(f"Detected model name from deployment: {model_name}")
self.planner.model_name = model_name.lower() self.planner.model_name = model_name.lower()
else: else:
model_name = getattr(self.config, "model_name", None) if not self.config.model_name:
if not model_name:
raise ValueError( raise ValueError(
"Model name is required in no-operation mode. " "Model name is required in no-operation mode. "
"Please set model_name in the config." "Please set model_name in the config."
) )
self.planner.model_name = model_name.lower() self.planner.model_name = self.config.model_name.lower()
loops = [ loops = [
self._load_loop(), self._load_loop(),
......
...@@ -90,12 +90,17 @@ class DecodePlanner(BasePlanner): ...@@ -90,12 +90,17 @@ class DecodePlanner(BasePlanner):
"No decode workers found for correction factor, skipping correction update" "No decode workers found for correction factor, skipping correction update"
) )
return True return True
assert self.last_metrics.num_req is not None
assert self.last_metrics.request_duration is not None
assert self.last_metrics.isl is not None
assert self.last_metrics.osl is not None
assert self.last_metrics.itl is not None
expect_itl = self.decode_interpolator.interpolate_itl( expect_itl = self.decode_interpolator.interpolate_itl(
concurrency=self.last_metrics.num_req # type: ignore concurrency=self.last_metrics.num_req
/ self.shared_state.num_d_workers / self.shared_state.num_d_workers
* self.last_metrics.request_duration # type: ignore * self.last_metrics.request_duration
/ self.config.throughput_adjustment_interval, / self.config.throughput_adjustment_interval,
context_length=self.last_metrics.isl + self.last_metrics.osl / 2, # type: ignore context_length=self.last_metrics.isl + self.last_metrics.osl / 2,
) )
self.d_correction_factor = self.last_metrics.itl / expect_itl self.d_correction_factor = self.last_metrics.itl / expect_itl
logger.info(f"Correction factor (decode ITL): {self.d_correction_factor:.3f}") logger.info(f"Correction factor (decode ITL): {self.d_correction_factor:.3f}")
...@@ -126,6 +131,7 @@ class DecodePlanner(BasePlanner): ...@@ -126,6 +131,7 @@ class DecodePlanner(BasePlanner):
"(no throughput satisfies ITL target), falling back to min_endpoint" "(no throughput satisfies ITL target), falling back to min_endpoint"
) )
return self.config.min_endpoint return self.config.min_endpoint
assert self.config.decode_engine_num_gpu is not None
pred_decode_throughput = ( pred_decode_throughput = (
next_num_req * next_osl / self.config.throughput_adjustment_interval next_num_req * next_osl / self.config.throughput_adjustment_interval
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment