# test_remote_planner.py
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Unit tests for remote planner components.

Tests RemotePlannerClient (low-level) and GlobalPlannerConnector (high-level)
for delegating scale requests to GlobalPlanner.
"""

import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from dynamo.planner import SubComponentType, TargetReplica
from dynamo.planner.global_planner_connector import GlobalPlannerConnector
from dynamo.planner.remote_planner_client import RemotePlannerClient
from dynamo.planner.scale_protocol import ScaleRequest, ScaleResponse, ScaleStatus
from dynamo.planner.utils.exceptions import EmptyTargetReplicasError

# Marks applied to every test in this module, in this order. ``gpu_0``
# presumably means "needs no GPU" (TODO confirm against the marker registry);
# the others select pre-merge CI, unit-level and planner-suite runs.
pytestmark = [
    getattr(pytest.mark, marker_name)
    for marker_name in ("gpu_0", "pre_merge", "unit", "planner")
]


@pytest.fixture
def mock_runtime():
    """Build a mocked DistributedRuntime wired to a mocked endpoint client.

    ``runtime.endpoint(...)`` returns an endpoint mock whose ``client()`` is an
    ``AsyncMock`` resolving to ``client_mock``.  ``client_mock.scale_request``
    resolves to a successful response payload reporting 3 prefill and 5
    decode replicas.

    Returns:
        tuple: ``(runtime, client_mock)`` so a test can build a
        ``RemotePlannerClient`` from ``runtime`` and inspect the calls that
        reach ``client_mock``.
    """
    runtime = MagicMock()
    endpoint_mock = MagicMock()
    client_mock = AsyncMock()

    runtime.endpoint.return_value = endpoint_mock
    # ``client()`` is an AsyncMock so that awaiting it yields ``client_mock``.
    endpoint_mock.client = AsyncMock(return_value=client_mock)
    client_mock.wait_for_instances = AsyncMock()

    # Mock scale_request to return a successful response payload.
    client_mock.scale_request = AsyncMock(
        return_value={
            "status": "success",
            "message": "Scaled successfully",
            "current_replicas": {"prefill": 3, "decode": 5},
        }
    )

    return runtime, client_mock


@pytest.mark.asyncio
async def test_send_scale_request_success(mock_runtime):
    """Test successful scale request (exercises protocol, client, and serialization).

    Sends a two-component ScaleRequest through ``RemotePlannerClient`` and
    checks that the mocked success payload is parsed into a ``ScaleResponse``,
    that the endpoint client was created lazily for the expected endpoint
    path, and that exactly one RPC reached the mocked client.
    """
    runtime, mock_client = mock_runtime
    client = RemotePlannerClient(runtime, "central-ns", "Planner")

    request = ScaleRequest(
        caller_namespace="app-ns",
        graph_deployment_name="my-dgd",
        k8s_namespace="default",
        target_replicas=[
            TargetReplica(
                sub_component_type=SubComponentType.PREFILL, desired_replicas=3
            ),
            TargetReplica(
                sub_component_type=SubComponentType.DECODE, desired_replicas=5
            ),
        ],
        blocking=False,
    )

    response = await client.send_scale_request(request)

    assert response.status == ScaleStatus.SUCCESS
    assert response.message == "Scaled successfully"
    assert response.current_replicas["prefill"] == 3
    assert response.current_replicas["decode"] == 5
    # Verify lazy init happened against the "<ns>.<component>.scale_request"
    # endpoint.
    assert client._client is not None
    runtime.endpoint.assert_called_once_with("central-ns.Planner.scale_request")
    # A single request must translate into a single RPC.
    mock_client.scale_request.assert_awaited_once()


@pytest.mark.asyncio
async def test_send_scale_request_error():
    """Test scale request error handling.

    The mocked endpoint client replies with an ``"error"`` payload;
    ``send_scale_request`` must return a ``ScaleResponse`` whose status is
    ``ScaleStatus.ERROR`` and whose message carries the remote reason, rather
    than raising.
    """
    runtime = MagicMock()
    endpoint_mock = MagicMock()
    client_mock = AsyncMock()

    runtime.endpoint.return_value = endpoint_mock
    endpoint_mock.client = AsyncMock(return_value=client_mock)
    client_mock.wait_for_instances = AsyncMock()

    # Mock scale_request to return error response
    client_mock.scale_request = AsyncMock(
        return_value={
            "status": "error",
            "message": "Namespace not authorized",
            "current_replicas": {},
        }
    )

    client = RemotePlannerClient(runtime, "central-ns", "Planner")

    request = ScaleRequest(
        caller_namespace="unauthorized-ns",
        graph_deployment_name="my-dgd",
        k8s_namespace="default",
        target_replicas=[
            TargetReplica(
                sub_component_type=SubComponentType.PREFILL, desired_replicas=1
            )
        ],
    )

    response = await client.send_scale_request(request)

    assert response.status == ScaleStatus.ERROR
    assert "not authorized" in response.message


@pytest.mark.asyncio
async def test_send_scale_request_no_response():
    """Test scale request when no response is received.

    The mocked endpoint client's ``scale_request`` resolves to ``None``;
    ``send_scale_request`` must raise ``RuntimeError`` mentioning
    "No response from centralized planner" instead of returning a response.
    """
    runtime = MagicMock()
    endpoint_mock = MagicMock()
    client_mock = AsyncMock()

    runtime.endpoint.return_value = endpoint_mock
    endpoint_mock.client = AsyncMock(return_value=client_mock)
    client_mock.wait_for_instances = AsyncMock()

    # Mock scale_request to return None (no reply from the remote planner).
    client_mock.scale_request = AsyncMock(return_value=None)

    client = RemotePlannerClient(runtime, "central-ns", "Planner")

    request = ScaleRequest(
        caller_namespace="app-ns",
        graph_deployment_name="my-dgd",
        k8s_namespace="default",
        target_replicas=[
            TargetReplica(
                sub_component_type=SubComponentType.PREFILL, desired_replicas=1
            )
        ],
    )

    with pytest.raises(RuntimeError, match="No response from centralized planner"):
        await client.send_scale_request(request)


@pytest.mark.asyncio
async def test_multiple_requests_reuse_client(mock_runtime):
    """Test that multiple requests reuse the same client instance.

    The mocked ``endpoint.client()`` always resolves to the same object, so an
    identity check on ``client._client`` alone would also pass if the client
    were re-created for every request. The test therefore also asserts that
    the endpoint client was created only once while both requests were
    delivered through it.
    """
    runtime, mock_client = mock_runtime
    client = RemotePlannerClient(runtime, "central-ns", "Planner")

    request1 = ScaleRequest(
        caller_namespace="app-ns",
        graph_deployment_name="my-dgd",
        k8s_namespace="default",
        target_replicas=[
            TargetReplica(
                sub_component_type=SubComponentType.PREFILL, desired_replicas=2
            )
        ],
    )

    request2 = ScaleRequest(
        caller_namespace="app-ns",
        graph_deployment_name="my-dgd",
        k8s_namespace="default",
        target_replicas=[
            TargetReplica(
                sub_component_type=SubComponentType.PREFILL, desired_replicas=4
            )
        ],
    )

    # Send first request
    await client.send_scale_request(request1)
    first_client = client._client

    # Send second request
    await client.send_scale_request(request2)
    second_client = client._client

    # Should be the same client instance...
    assert first_client is second_client
    # ...and it must not have been re-created for the second request.
    runtime.endpoint.return_value.client.assert_awaited_once()
    # Both requests went out through the reused client.
    assert mock_client.scale_request.await_count == 2


# ============================================================================
# GlobalPlannerConnector Tests
# ============================================================================


@pytest.fixture
def connector_runtime():
    """Provide a bare MagicMock standing in for the GlobalPlannerConnector runtime."""
    fake_runtime = MagicMock()
    return fake_runtime


@pytest.fixture
def connector(connector_runtime):
    """Build a GlobalPlannerConnector bound to the mocked runtime.

    The connector uses dynamo namespace ``test-ns``, global planner namespace
    ``global-ns`` and model ``test-model``.
    """
    connector_settings = {
        "dynamo_namespace": "test-ns",
        "global_planner_namespace": "global-ns",
        "model_name": "test-model",
    }
    return GlobalPlannerConnector(runtime=connector_runtime, **connector_settings)


@pytest.mark.asyncio
async def test_connector_initialization(connector, connector_runtime):
    """Check constructor state, then that _async_init creates the remote client.

    A freshly built connector exposes its namespaces and has no remote client.
    ``_async_init`` must construct ``RemotePlannerClient`` with the connector's
    runtime, the global planner namespace and the ``GlobalPlanner`` component,
    and store the result as ``remote_client``.
    """
    assert connector.dynamo_namespace == "test-ns"
    assert connector.global_planner_namespace == "global-ns"
    assert connector.remote_client is None

    patch_target = "dynamo.planner.global_planner_connector.RemotePlannerClient"
    with patch(patch_target) as remote_client_cls:
        fake_remote_client = MagicMock()
        remote_client_cls.return_value = fake_remote_client

        await connector._async_init()

        remote_client_cls.assert_called_once_with(
            connector_runtime, "global-ns", "GlobalPlanner"
        )
        assert connector.remote_client == fake_remote_client


@pytest.mark.asyncio
async def test_connector_set_replicas_success(connector):
    """Scale through the connector and inspect the ScaleRequest it forwards.

    Covers conversion of ``SubComponentType`` enums to plain strings, the
    caller namespace, the non-blocking flag, and forwarding of the load set
    via ``set_predicted_load``.
    """
    prefill_target = TargetReplica(
        sub_component_type=SubComponentType.PREFILL,
        component_name="prefill-svc",
        desired_replicas=3,
    )
    decode_target = TargetReplica(
        sub_component_type=SubComponentType.DECODE,
        component_name="decode-svc",
        desired_replicas=5,
    )

    fake_env = {"DYN_PARENT_DGD_K8S_NAME": "dgd", "POD_NAMESPACE": "ns"}
    with patch.dict(os.environ, fake_env):
        ok_response = ScaleResponse(
            status=ScaleStatus.SUCCESS,
            message="OK",
            current_replicas={"prefill": 3, "decode": 5},
        )
        remote = AsyncMock()
        remote.send_scale_request = AsyncMock(return_value=ok_response)
        connector.remote_client = remote
        connector.set_predicted_load(100.0, 512.0, 256.0)

        await connector.set_component_replicas(
            [prefill_target, decode_target], blocking=False
        )

        # The request object is the first positional argument of the last call.
        sent = remote.send_scale_request.call_args.args[0]
        first_target = sent.target_replicas[0]

        assert sent.caller_namespace == "test-ns"
        assert sent.blocking is False
        assert sent.predicted_load["num_requests"] == 100.0
        assert len(sent.target_replicas) == 2
        # Enums must have been converted to their plain string value.
        assert first_target.sub_component_type == "prefill"
        assert isinstance(first_target.sub_component_type, str)


@pytest.mark.asyncio
async def test_connector_error_handling(connector):
    """Exercise the connector's three failure paths, in order.

    1. An empty target list raises ``EmptyTargetReplicasError``.
    2. Scaling while ``remote_client`` is unset raises a ``RuntimeError``
       matching "not initialized".
    3. An ``ERROR`` response from the GlobalPlanner surfaces as a
       ``RuntimeError`` matching "GlobalPlanner scaling failed".
    """
    with pytest.raises(EmptyTargetReplicasError):
        await connector.set_component_replicas([])

    one_prefill = [
        TargetReplica(
            sub_component_type=SubComponentType.PREFILL,
            component_name="p",
            desired_replicas=1,
        )
    ]

    # The fixture connector has not been given a remote client yet.
    with pytest.raises(RuntimeError, match="not initialized"):
        await connector.set_component_replicas(one_prefill)

    fake_env = {"DYN_PARENT_DGD_K8S_NAME": "d", "POD_NAMESPACE": "n"}
    with patch.dict(os.environ, fake_env):
        failed_response = ScaleResponse(
            status=ScaleStatus.ERROR, message="Failed", current_replicas={}
        )
        failing_remote = AsyncMock()
        failing_remote.send_scale_request = AsyncMock(return_value=failed_response)
        connector.remote_client = failing_remote

        with pytest.raises(RuntimeError, match="GlobalPlanner scaling failed"):
            await connector.set_component_replicas(one_prefill)


@pytest.mark.asyncio
async def test_connector_unsupported_and_noop_operations(connector):
    """Single-component add/remove are rejected; validation/readiness are no-ops.

    ``add_component`` and ``remove_component`` must raise
    ``NotImplementedError`` mentioning "batch operations".
    ``validate_deployment`` and ``wait_for_deployment_ready`` must complete
    without raising.
    """
    unsupported_calls = (
        (connector.add_component, SubComponentType.PREFILL),
        (connector.remove_component, SubComponentType.DECODE),
    )
    for operation, sub_component in unsupported_calls:
        with pytest.raises(NotImplementedError, match="batch operations"):
            await operation(sub_component)

    # These are expected to succeed without doing anything observable here.
    await connector.validate_deployment(
        prefill_component_name="p", decode_component_name="d"
    )
    await connector.wait_for_deployment_ready()


def test_connector_model_name_and_predicted_load(connector_runtime):
    """Check the model-name fallback and predicted-load bookkeeping.

    An explicit ``model_name`` is reported as-is, and ``None`` falls back to
    ``"managed-remotely"``. ``set_predicted_load`` records its three arguments
    in ``last_predicted_load`` under the keys ``num_requests``, ``isl`` and
    ``osl``.
    """
    named = GlobalPlannerConnector(
        connector_runtime, "ns", "gns", "GP", model_name="test"
    )
    assert named.get_model_name() == "test"

    unnamed = GlobalPlannerConnector(
        connector_runtime, "ns", "gns", "GP", model_name=None
    )
    assert unnamed.get_model_name() == "managed-remotely"

    named.set_predicted_load(42.0, 256.0, 128.0)
    expected_load = {"num_requests": 42.0, "isl": 256.0, "osl": 128.0}
    assert named.last_predicted_load == expected_load