# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import asyncio
import contextlib
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Optional, Tuple

import pytest
10
11
12
13
14
15
16

try:
    import tritonclient.grpc.model_config_pb2 as mc
    from tritonclient.utils import InferenceServerException
except ImportError:
    mc = None
    InferenceServerException = None
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76

from dynamo.llm import KserveGrpcService, ModelRuntimeConfig, PythonAsyncEngine

pytestmark = pytest.mark.pre_merge


async def _fetch_model_config(
    client,
    model_name: str,
    retries: int = 30,
) -> Any:
    """Fetch the config of ``model_name`` through ``client``, retrying on errors.

    The blocking ``client.get_model_config`` call is dispatched with
    ``asyncio.to_thread`` so it runs in a worker thread while this coroutine
    awaits the result.

    Args:
        client: Object exposing a blocking ``get_model_config(model_name)``
            method.
        model_name: Name of the model whose config is requested.
        retries: Maximum number of attempts before giving up.

    Returns:
        The value returned by the first successful ``get_model_config`` call.

    Raises:
        AssertionError: If no attempt succeeded (every attempt raised
            ``InferenceServerException``, or ``retries`` is not positive).
            The last error seen is attached as the exception's cause.
    """
    last_error: Optional[Exception] = None
    for attempt in range(retries):
        try:
            return await asyncio.to_thread(client.get_model_config, model_name)
        except InferenceServerException as err:
            last_error = err
        # Pause only when another attempt follows; sleeping after the final
        # failure would just delay the error report.
        if attempt + 1 < retries:
            await asyncio.sleep(0.1)
    # Chain the last server error so its traceback is kept in the report.
    raise AssertionError(
        f"Unable to fetch model config for '{model_name}': {last_error}"
    ) from last_error


class EchoTensorEngine:
    """Tensor engine stub that echoes each request back as a single response."""

    def __init__(self, model_name: str):
        # Name reported in the "model" field of every response.
        self._model_name = model_name

    def generate(self, request, context=None):
        """Return an async generator yielding one echo response for ``request``.

        The response carries the engine's model name plus the request's
        ``"tensors"`` (default ``[]``) and ``"parameters"`` (default ``{}``).
        ``context`` is accepted but not used.
        """

        async def _echo_once():
            name = self._model_name
            tensors = request.get("tensors", [])
            parameters = request.get("parameters", {})
            yield {"model": name, "tensors": tensors, "parameters": parameters}

        stream = _echo_once()
        return stream


@pytest.fixture
def tensor_service(runtime):
    """Provide a factory for a KServe gRPC service hosting one echo tensor model.

    The fixture returns ``_start``, an async context manager. Entering it
    registers an ``EchoTensorEngine``-backed tensor model, launches the
    service as a background task and yields ``(host, port)``; leaving it
    shuts the service down.

    ``runtime`` is another fixture (defined outside this module's visible
    code) that is passed straight to ``KserveGrpcService.run``.
    """

    @asynccontextmanager
    async def _start(
        model_name: str,
        *,
        runtime_config: Optional[ModelRuntimeConfig] = None,
        checksum: str = "dummy-mdcsum",
    ) -> AsyncIterator[Tuple[str, int]]:
        """Serve ``model_name`` on 127.0.0.1:8787 for the duration of the block.

        Args:
            model_name: Name under which the tensor model is registered.
            runtime_config: Optional runtime config forwarded to
                ``add_tensor_model``.
            checksum: Checksum string forwarded to ``add_tensor_model``.

        Yields:
            The ``(host, port)`` pair the service was created with.
        """
        host = "127.0.0.1"
        port = 8787
        loop = asyncio.get_running_loop()
        engine = PythonAsyncEngine(EchoTensorEngine(model_name).generate, loop)
        tensor_model_service = KserveGrpcService(port=port, host=host)

        tensor_model_service.add_tensor_model(
            model_name, checksum, engine, runtime_config=runtime_config
        )

        async def _serve():
            await tensor_model_service.run(runtime)

        server_task = asyncio.create_task(_serve())
        try:
            await asyncio.sleep(1)  # wait service to start
            yield host, port
        finally:
            tensor_model_service.shutdown()
            # Give the server task up to 5 seconds to finish after shutdown;
            # a timeout or cancellation here is deliberately ignored.
            with contextlib.suppress(asyncio.TimeoutError, asyncio.CancelledError):
                await asyncio.wait_for(server_task, timeout=5)

    return _start


@pytest.mark.asyncio
@pytest.mark.forked
async def test_model_config_uses_runtime_config(tensor_service):
    """Check that ModelConfig reports the tensor config stored in runtime_config."""
    import tritonclient.grpc as grpcclient

    model_name = "tensor-config-model"

    # Tensor model description registered with the service.
    input_specs = [
        {"name": "input_text", "data_type": "Bytes", "shape": [-1]},
        {"name": "control_flag", "data_type": "Bool", "shape": [1]},
    ]
    output_specs = [
        {"name": "results", "data_type": "Bytes", "shape": [-1]},
    ]
    runtime_config = ModelRuntimeConfig()
    runtime_config.set_tensor_model_config(
        {"name": model_name, "inputs": input_specs, "outputs": output_specs}
    )

    async with tensor_service(model_name, runtime_config=runtime_config) as endpoint:
        host, port = endpoint
        # closing() guarantees client.close() runs even if the fetch fails.
        with contextlib.closing(
            grpcclient.InferenceServerClient(url=f"{host}:{port}")
        ) as client:
            response = await _fetch_model_config(client, model_name)

    config = response.config
    assert config.name == model_name
    assert config.platform == "dynamo"
    assert config.backend == "dynamo"

    input_by_name = {spec.name: spec for spec in config.input}
    text_spec = input_by_name["input_text"]
    assert list(text_spec.dims) == [-1]
    assert text_spec.data_type == mc.TYPE_STRING
    flag_spec = input_by_name["control_flag"]
    assert list(flag_spec.dims) == [1]
    assert flag_spec.data_type == mc.TYPE_BOOL

    output_by_name = {spec.name: spec for spec in config.output}
    results_spec = output_by_name["results"]
    assert list(results_spec.dims) == [-1]
    assert results_spec.data_type == mc.TYPE_STRING


@pytest.mark.asyncio
@pytest.mark.forked
async def test_model_config_missing_runtime_config_errors(tensor_service):
    """Expect a "not found" error from ModelConfig when no runtime_config is set."""
    import tritonclient.grpc as grpcclient

    model_name = "tensor-config-missing"

    async with tensor_service(model_name, runtime_config=None) as endpoint:
        host, port = endpoint
        # closing() guarantees client.close() runs whatever the request does.
        with contextlib.closing(
            grpcclient.InferenceServerClient(url=f"{host}:{port}")
        ) as client:
            with pytest.raises(InferenceServerException) as excinfo:
                await asyncio.to_thread(client.get_model_config, model_name)

    message = str(excinfo.value)
    assert "not found" in message.lower()