# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging

import sglang as sgl
from utils.protocol import DisaggPreprocessedRequest
from utils.sglang import parse_sglang_args

from dynamo.sdk import endpoint, service

logger = logging.getLogger(__name__)


@service(
    dynamo={
        "enabled": True,
        "namespace": "dynamo",
    },
    resources={"gpu": 1},
    workers=1,
)
class SGLangDecodeWorker:
    """Dynamo service that owns an SGLang engine and streams its output.

    Registered in the ``dynamo`` namespace with one worker and one GPU
    requested. The engine is created once in ``__init__`` and reused by
    every ``generate`` call.
    """

    def __init__(self):
        """Parse the engine arguments for this class and start the engine."""
        # The worker's own class name is handed to the argument parser; the
        # second argument is an empty string (its meaning is defined by
        # parse_sglang_args, which is not visible here).
        self.engine_args = parse_sglang_args(self.__class__.__name__, "")
        self.engine = sgl.Engine(server_args=self.engine_args)

        logger.warning("Decode worker initialized")

    @endpoint()
    async def generate(self, req: DisaggPreprocessedRequest):
        """Run a streaming generation for ``req`` and yield each result.

        Args:
            req: Preprocessed request. This method reads
                ``req.request.token_ids``, ``req.request.batch_token_ids``,
                ``req.sampling_params`` and the three ``req.bootstrap_*``
                fields.

        Yields:
            Every item produced by the engine's stream, unmodified.
        """
        inner = req.request

        # Batched token ids take precedence; the single-sequence ids are the
        # fallback when no batch was supplied.
        input_ids = inner.batch_token_ids
        if input_ids is None:
            input_ids = inner.token_ids

        # The bootstrap_* values are forwarded to the engine untouched.
        # NOTE(review): presumably these identify the peer for the
        # disaggregated hand-off; confirm against the SGLang docs.
        stream = await self.engine.async_generate(
            input_ids=input_ids,
            sampling_params=req.sampling_params,
            stream=True,
            bootstrap_host=req.bootstrap_host,
            bootstrap_port=req.bootstrap_port,
            bootstrap_room=req.bootstrap_room,
        )

        # The awaited value is itself consumed as an async iterator.
        async for chunk in stream:
            yield chunk