# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging

from pydantic import BaseModel

from components.planner import start_planner  # type: ignore[attr-defined]
from dynamo.planner.defaults import LoadPlannerDefaults
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sdk import async_on_start, dynamo_context, endpoint, service
from dynamo.sdk.core.protocol.interface import ComponentType
from dynamo.sdk.lib.config import ServiceConfig
from dynamo.sdk.lib.image import DYNAMO_IMAGE

# Module-scoped logger; its name is this module's import path.
logger = logging.getLogger(name=__name__)


# Request payload for Planner.generate, which is only a placeholder endpoint.
class RequestType(BaseModel):
    # Free-form request text; the placeholder endpoint never reads it.
    text: str


@service(
    dynamo={
        "namespace": "dynamo",
        "component_type": ComponentType.PLANNER,
    },
    resources={"cpu": "10", "memory": "20Gi"},
    workers=1,
    image=DYNAMO_IMAGE,
)
class Planner:
    """Dynamo service that hosts the load planner.

    Planner options are read from the ``Planner`` section of the service
    configuration, using kebab-case keys such as ``max-gpu-budget``. Any key
    that is absent falls back to the same-named attribute of
    ``LoadPlannerDefaults``. The collected options, plus the active namespace,
    are passed to ``start_planner`` when the service starts.
    """

    def __init__(self) -> None:
        configure_dynamo_logging(service_name="Planner")
        logger.info("Starting planner")
        self.runtime = dynamo_context["runtime"]

        config = ServiceConfig.get_instance()

        # Get namespace directly from dynamo_context as it contains the active namespace
        self.namespace = dynamo_context["namespace"]
        config_instance = config.get("Planner", {})

        # Each pair is (argparse.Namespace attribute, key in the "Planner"
        # config section). Missing keys fall back to the same-named
        # LoadPlannerDefaults attribute. One table replaces thirteen
        # copy-pasted lookups, so a key and its default cannot drift apart.
        option_keys = (
            ("environment", "environment"),
            ("no_operation", "no-operation"),
            ("log_dir", "log-dir"),
            ("adjustment_interval", "adjustment-interval"),
            ("metric_pulling_interval", "metric-pulling-interval"),
            ("max_gpu_budget", "max-gpu-budget"),
            ("min_endpoint", "min-endpoint"),
            ("decode_kv_scale_up_threshold", "decode-kv-scale-up-threshold"),
            ("decode_kv_scale_down_threshold", "decode-kv-scale-down-threshold"),
            (
                "prefill_queue_scale_up_threshold",
                "prefill-queue-scale-up-threshold",
            ),
            (
                "prefill_queue_scale_down_threshold",
                "prefill-queue-scale-down-threshold",
            ),
            ("decode_engine_num_gpu", "decode-engine-num-gpu"),
            ("prefill_engine_num_gpu", "prefill-engine-num-gpu"),
        )
        self.args = argparse.Namespace(
            namespace=self.namespace,
            **{
                attr: config_instance.get(key, getattr(LoadPlannerDefaults, attr))
                for attr, key in option_keys
            },
        )

    @async_on_start
    async def async_init(self):
        """Wait a fixed start-up delay, then run ``start_planner``.

        ``start_planner`` receives the runtime and the options assembled in
        ``__init__``.
        """
        import asyncio

        # NOTE(review): fixed delay before planning begins; presumably it gives
        # the other components time to come up and register. Confirm, and
        # consider making it configurable.
        startup_delay_s = 30
        await asyncio.sleep(startup_delay_s)
        logger.info("Calling start_planner")
        await start_planner(self.runtime, self.args)
        logger.info("Planner started")

    @endpoint()
    async def generate(self, request: RequestType):
        """Dummy endpoint to satisfy that each component has an endpoint"""
        yield "mock endpoint"