test_kserve_grpc.py 4.94 KB
Newer Older
1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
8
9
# SPDX-License-Identifier: Apache-2.0

import asyncio
import contextlib
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Optional, Tuple

import pytest
10
11
12
13
14
15
16

try:
    import tritonclient.grpc.model_config_pb2 as mc
    from tritonclient.utils import InferenceServerException
except ImportError:
    mc = None
    InferenceServerException = None
17
18
19

from dynamo.llm import KserveGrpcService, ModelRuntimeConfig, PythonAsyncEngine

20
21
22
23
24
pytestmark = [
    pytest.mark.gpu_0,
    pytest.mark.pre_merge,
    pytest.mark.integration,
]
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80


async def _fetch_model_config(
    client,
    model_name: str,
    retries: int = 30,
) -> Any:
    last_error: Optional[Exception] = None
    for _ in range(retries):
        try:
            return await asyncio.to_thread(client.get_model_config, model_name)
        except InferenceServerException as err:
            last_error = err
            await asyncio.sleep(0.1)
    raise AssertionError(
        f"Unable to fetch model config for '{model_name}': {last_error}"
    )


class EchoTensorEngine:
    """Minimal tensor engine stub for registering tensor models."""

    def __init__(self, model_name: str):
        self._model_name = model_name

    def generate(self, request, context=None):
        async def _generator():
            yield {
                "model": self._model_name,
                "tensors": request.get("tensors", []),
                "parameters": request.get("parameters", {}),
            }

        return _generator()


@pytest.fixture
def tensor_service(runtime):
    @asynccontextmanager
    async def _start(
        model_name: str,
        *,
        runtime_config: Optional[ModelRuntimeConfig] = None,
        checksum: str = "dummy-mdcsum",
    ) -> AsyncIterator[Tuple[str, int]]:
        host = "127.0.0.1"
        port = 8787
        loop = asyncio.get_running_loop()
        engine = PythonAsyncEngine(EchoTensorEngine(model_name).generate, loop)
        tensor_model_service = KserveGrpcService(port=port, host=host)

        tensor_model_service.add_tensor_model(
            model_name, checksum, engine, runtime_config=runtime_config
        )

        async def _serve():
81
            await tensor_model_service.run(runtime)
82
83
84
85
86
87

        server_task = asyncio.create_task(_serve())
        try:
            await asyncio.sleep(1)  # wait service to start
            yield host, port
        finally:
88
            tensor_model_service.shutdown()
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
            with contextlib.suppress(asyncio.TimeoutError, asyncio.CancelledError):
                await asyncio.wait_for(server_task, timeout=5)

    return _start


@pytest.mark.asyncio
@pytest.mark.forked
async def test_model_config_uses_runtime_config(tensor_service):
    """Ensure tensor runtime_config is returned via the ModelConfig endpoint."""
    import tritonclient.grpc as grpcclient

    model_name = "tensor-config-model"
    tensor_config = {
        "name": model_name,
        "inputs": [
            {"name": "input_text", "data_type": "Bytes", "shape": [-1]},
            {"name": "control_flag", "data_type": "Bool", "shape": [1]},
        ],
        "outputs": [
            {"name": "results", "data_type": "Bytes", "shape": [-1]},
        ],
    }
    runtime_config = ModelRuntimeConfig()
    runtime_config.set_tensor_model_config(tensor_config)

    async with tensor_service(model_name, runtime_config=runtime_config) as (
        host,
        port,
    ):
        client = grpcclient.InferenceServerClient(url=f"{host}:{port}")
        try:
            response = await _fetch_model_config(client, model_name)
        finally:
            client.close()

    model_config = response.config
    assert model_config.name == model_name
    assert model_config.platform == "dynamo"
    assert model_config.backend == "dynamo"

    inputs = {spec.name: spec for spec in model_config.input}
    assert list(inputs["input_text"].dims) == [-1]
    assert inputs["input_text"].data_type == mc.TYPE_STRING
    assert list(inputs["control_flag"].dims) == [1]
    assert inputs["control_flag"].data_type == mc.TYPE_BOOL

    outputs = {spec.name: spec for spec in model_config.output}
    assert list(outputs["results"].dims) == [-1]
    assert outputs["results"].data_type == mc.TYPE_STRING


@pytest.mark.asyncio
@pytest.mark.forked
async def test_model_config_missing_runtime_config_errors(tensor_service):
    """ModelConfig should return NOT_FOUND when no tensor runtime_config is saved."""
    model_name = "tensor-config-missing"
    import tritonclient.grpc as grpcclient

    async with tensor_service(model_name, runtime_config=None) as (host, port):
        client = grpcclient.InferenceServerClient(url=f"{host}:{port}")
        try:
            with pytest.raises(InferenceServerException) as excinfo:
                await asyncio.to_thread(client.get_model_config, model_name)
        finally:
            client.close()

    assert "not found" in str(excinfo.value).lower()