test_grpc_health.py 4.44 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

grpc = pytest.importorskip("grpc")
health_pb2 = pytest.importorskip("grpc_health.v1.health_pb2")
VllmHealthServicer = pytest.importorskip(
    "smg_grpc_servicer.vllm.health_servicer"
).VllmHealthServicer

# Short aliases for the HealthCheckResponse status enum values from the
# gRPC health-checking proto module, used by the assertions below.
SERVING, NOT_SERVING, SERVICE_UNKNOWN = (
    health_pb2.HealthCheckResponse.SERVING,
    health_pb2.HealthCheckResponse.NOT_SERVING,
    health_pb2.HealthCheckResponse.SERVICE_UNKNOWN,
)


@pytest.fixture
def async_llm():
    """Stand-in for the engine client handed to ``VllmHealthServicer``.

    ``check_health`` is an ``AsyncMock`` so the servicer can ``await`` it and
    individual tests can replace it with one whose ``side_effect`` raises, to
    simulate a failing engine.
    """
    return MagicMock(check_health=AsyncMock())


@pytest.fixture
def context():
    """Mock gRPC aio servicer context limited to the real context API.

    ``spec=`` makes access to attributes that ``grpc.aio.ServicerContext``
    does not define raise ``AttributeError``. Calls such as ``set_code`` and
    ``set_details`` are recorded so tests can assert on them.
    """
    servicer_context = MagicMock(spec=grpc.aio.ServicerContext)
    return servicer_context


@pytest.fixture
def servicer(async_llm):
    """The ``VllmHealthServicer`` under test, backed by the mocked engine."""
    health_servicer = VllmHealthServicer(async_llm)
    return health_servicer


@pytest.fixture
def request_msg():
    """Build a real ``HealthCheckRequest`` for the overall server (``""``).

    A genuine protobuf message is used rather than a ``MagicMock``. A bare mock
    silently fabricates any attribute it is asked for, so a servicer reading a
    misspelled or nonexistent request field would not be caught. A real
    message only exposes the fields the proto defines. Tests overwrite
    ``.service`` as needed.
    """
    return health_pb2.HealthCheckRequest(service="")


# -- Check() tests --


@pytest.mark.asyncio
async def test_check_serving_overall(servicer, request_msg, context, async_llm):
    """An empty service name (overall health) is SERVING when the engine is healthy.

    The engine health probe must be awaited exactly once.
    """
    request_msg.service = ""  # empty name = overall server health

    result = await servicer.Check(request_msg, context)

    async_llm.check_health.assert_awaited_once()
    assert result.status == SERVING


@pytest.mark.asyncio
async def test_check_serving_vllm_service(servicer, request_msg, context, async_llm):
    """Asking for the vLLM engine service by its full name reports SERVING.

    The engine health probe must be awaited exactly once.
    """
    request_msg.service = "vllm.grpc.engine.VllmEngine"

    reply = await servicer.Check(request_msg, context)

    async_llm.check_health.assert_awaited_once()
    assert reply.status == SERVING


@pytest.mark.asyncio
async def test_check_not_serving_engine_errored(
    servicer, request_msg, context, async_llm
):
    """A failing engine health probe makes Check report NOT_SERVING.

    The test also asserts that the probe was actually awaited. This ties the
    NOT_SERVING result to the simulated engine error rather than to some
    other code path that never consults the engine.
    """
    async_llm.check_health = AsyncMock(side_effect=Exception("engine dead"))
    request_msg.service = ""
    response = await servicer.Check(request_msg, context)
    assert response.status == NOT_SERVING
    async_llm.check_health.assert_awaited()


@pytest.mark.asyncio
async def test_check_not_serving_shutting_down(
    servicer, request_msg, context, async_llm
):
    """Once ``set_not_serving()`` is called, Check short-circuits to NOT_SERVING.

    The engine must not be probed at all in this state (e.g. while shutting
    down).
    """
    servicer.set_not_serving()
    request_msg.service = ""

    reply = await servicer.Check(request_msg, context)

    async_llm.check_health.assert_not_awaited()
    assert reply.status == NOT_SERVING


@pytest.mark.asyncio
async def test_check_unknown_service_status(servicer, request_msg, context):
    """Check answers SERVICE_UNKNOWN for a service name it does not serve."""
    request_msg.service = "nonexistent.Service"
    assert (await servicer.Check(request_msg, context)).status == SERVICE_UNKNOWN


@pytest.mark.asyncio
async def test_check_unknown_service_grpc_code(servicer, request_msg, context):
    """Check also flags an unknown service on the gRPC context.

    The status code must be NOT_FOUND, and the details string (the first
    positional argument to ``set_details``) must name the requested service.
    """
    request_msg.service = "fake.Svc"
    await servicer.Check(request_msg, context)

    context.set_code.assert_called_once_with(grpc.StatusCode.NOT_FOUND)
    context.set_details.assert_called_once()
    details = context.set_details.call_args.args[0]
    assert "fake.Svc" in details


@pytest.mark.asyncio
@patch("smg_grpc_servicer.vllm.health_servicer.logger")
async def test_check_logs_exception_on_error(
    mock_logger, servicer, request_msg, context, async_llm
):
    """A failing engine probe is logged exactly once via ``logger.exception``.

    The rendered call (message plus arguments), lower-cased, must contain the
    word ``service``.
    """
    async_llm.check_health = AsyncMock(side_effect=Exception("engine exploded"))
    request_msg.service = ""
    await servicer.Check(request_msg, context)

    exception_log = mock_logger.exception
    exception_log.assert_called_once()
    assert "service" in str(exception_log.call_args).lower()


# -- Watch() tests --


@pytest.mark.asyncio
async def test_watch_yields_serving(servicer, request_msg, context, async_llm):
    """Watch's first streamed update reports SERVING when the engine is healthy.

    Only the first response is consumed, so the stream is closed explicitly.
    Otherwise the suspended generator would be finalized by garbage
    collection or event-loop shutdown, outside this test's control. The test
    also checks that the engine probe was actually consulted.
    """
    request_msg.service = ""
    stream = servicer.Watch(request_msg, context)
    try:
        first = await anext(stream)
    finally:
        # NOTE(review): assumes Watch is an async generator function, so the
        # returned object has aclose(). Closing throws GeneratorExit at its
        # pending ``yield``, so any finally-blocks in Watch run now.
        await stream.aclose()
    assert first.status == SERVING
    async_llm.check_health.assert_awaited()


@pytest.mark.asyncio
async def test_watch_yields_not_serving(servicer, request_msg, context, async_llm):
    """Watch's first streamed update reports NOT_SERVING when the engine fails.

    Only the first response is consumed, so the stream is closed explicitly
    rather than left suspended for garbage collection or event-loop shutdown
    to finalize. The test also checks that the failing probe was awaited.
    """
    async_llm.check_health = AsyncMock(side_effect=Exception("engine down"))
    request_msg.service = ""
    stream = servicer.Watch(request_msg, context)
    try:
        first = await anext(stream)
    finally:
        # NOTE(review): assumes Watch is an async generator function, so the
        # returned object has aclose(). Closing throws GeneratorExit at its
        # pending ``yield``, so any finally-blocks in Watch run now.
        await stream.aclose()
    assert first.status == NOT_SERVING
    async_llm.check_health.assert_awaited()


@pytest.mark.asyncio
async def test_watch_unknown_service(servicer, request_msg, context):
    """Watch streams a single SERVICE_UNKNOWN update for an unknown service.

    The stream must then end on its own: the whole stream is drained here.
    """
    request_msg.service = "fake.Service"
    responses = [resp async for resp in servicer.Watch(request_msg, context)]

    assert len(responses) == 1
    assert responses[0].status == SERVICE_UNKNOWN
    # Check sets NOT_FOUND on the gRPC context for unknown services. Watch
    # instead reports SERVICE_UNKNOWN in the response body and leaves the
    # context's status code untouched, so the stream terminates normally.
    context.set_code.assert_not_called()