Commit 4c7dceca authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: cleaner required workers check (don't spam print) (#521)

parent 75360111
...@@ -15,12 +15,12 @@ ...@@ -15,12 +15,12 @@
import argparse import argparse
import asyncio
import random import random
from argparse import Namespace from argparse import Namespace
from typing import AsyncIterator from typing import AsyncIterator
from components.worker import VllmWorker from components.worker import VllmWorker
from utils.logging import check_required_workers
from utils.protocol import Tokens from utils.protocol import Tokens
from vllm.logger import logger as vllm_logger from vllm.logger import logger as vllm_logger
...@@ -98,14 +98,8 @@ class Router: ...@@ -98,14 +98,8 @@ class Router:
.endpoint("generate") .endpoint("generate")
.client() .client()
) )
while len(self.workers_client.endpoint_ids()) < self.args.min_workers:
# TODO: replace print w/ vllm_logger.info await check_required_workers(self.workers_client, self.args.min_workers)
print(
f"Waiting for more workers to be ready.\n"
f" Current: {len(self.workers_client.endpoint_ids())},"
f" Required: {self.args.min_workers}"
)
await asyncio.sleep(2)
kv_listener = self.runtime.namespace("dynamo").component("VllmWorker") kv_listener = self.runtime.namespace("dynamo").component("VllmWorker")
await kv_listener.create_service() await kv_listener.create_service()
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import uuid import uuid
from enum import Enum from enum import Enum
from typing import AsyncIterator, Tuple, Union from typing import AsyncIterator, Tuple, Union
...@@ -22,6 +21,7 @@ from components.kv_router import Router ...@@ -22,6 +21,7 @@ from components.kv_router import Router
from components.worker import VllmWorker from components.worker import VllmWorker
from transformers import AutoTokenizer from transformers import AutoTokenizer
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.logging import check_required_workers
from utils.protocol import MyRequestOutput, Tokens, vLLMGenerateRequest from utils.protocol import MyRequestOutput, Tokens, vLLMGenerateRequest
from utils.vllm import parse_vllm_args from utils.vllm import parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
...@@ -90,13 +90,8 @@ class Processor(ProcessMixIn): ...@@ -90,13 +90,8 @@ class Processor(ProcessMixIn):
.endpoint("generate") .endpoint("generate")
.client() .client()
) )
while len(self.worker_client.endpoint_ids()) < self.min_workers:
print( await check_required_workers(self.worker_client, self.min_workers)
f"Waiting for workers to be ready.\n"
f" Current: {len(self.worker_client.endpoint_ids())},"
f" Required: {self.min_workers}"
)
await asyncio.sleep(2)
async def _generate( async def _generate(
self, self,
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from dynamo._core import Client
async def check_required_workers(
workers_client: Client, required_workers: int, on_change=True, poll_interval=0.5
):
"""Wait until the minimum number of workers are ready."""
worker_ids = workers_client.endpoint_ids()
num_workers = len(worker_ids)
while num_workers < required_workers:
await asyncio.sleep(poll_interval)
worker_ids = workers_client.endpoint_ids()
new_count = len(worker_ids)
if (not on_change) or new_count != num_workers:
print(
f"Waiting for more workers to be ready.\n"
f" Current: {new_count},"
f" Required: {required_workers}"
)
num_workers = new_count
print(f"Workers ready: {worker_ids}")
return worker_ids
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment