"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "022afbeb4efa22bb8a4656a2712cd66c6a811c23"
Unverified Commit 90d74637 authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

fix: properly setup and register vLLM worker for external / hybrid load...


fix: properly setup and register vLLM worker for external / hybrid load balancing. Update launch script (#6695)
Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
parent 9254e3d4
...@@ -16,6 +16,7 @@ from dataclasses import dataclass ...@@ -16,6 +16,7 @@ from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Final from typing import Any, AsyncGenerator, Dict, Final
import torch import torch
from vllm.config import VllmConfig
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
...@@ -260,6 +261,32 @@ def build_sampling_params_openai( ...@@ -260,6 +261,32 @@ def build_sampling_params_openai(
return sampling_params return sampling_params
def get_dp_range_for_worker(vllm_config: VllmConfig) -> range:
"""
Get the global DP rank range that this worker is responsible for based on vLLM config.
Note that the 'vllm_config' is normalized so the load balancing flags are set properly.
The return value is in the format of (start_dp_rank, managed_dp_size)."""
if vllm_config.parallel_config.data_parallel_external_lb:
# external load balancing, each worker is responsible for exactly 1 rank
return (vllm_config.parallel_config.data_parallel_rank, 1)
elif vllm_config.parallel_config.data_parallel_hybrid_lb:
# hybrid load balancing, each worker is responsible for a subset of local ranks
return (
vllm_config.parallel_config.data_parallel_rank,
vllm_config.parallel_config.data_parallel_size_local,
)
else:
# internal load balancing, the worker is responsible for all DP ranks
logger.warning(
"vLLM selects internal DP load balancing. If you are launching multiple workers for DP deployment,"
" hybrid or external load balancing is recommended."
)
return (
vllm_config.parallel_config.data_parallel_rank,
vllm_config.parallel_config.data_parallel_size,
)
class BaseWorkerHandler(ABC): class BaseWorkerHandler(ABC):
""" """
Request handler for the generate and clear_kv_blocks endpoints. Request handler for the generate and clear_kv_blocks endpoints.
...@@ -302,6 +329,8 @@ class BaseWorkerHandler(ABC): ...@@ -302,6 +329,8 @@ class BaseWorkerHandler(ABC):
self.use_vllm_tokenizer = use_vllm_tokenizer self.use_vllm_tokenizer = use_vllm_tokenizer
self.dp_range = get_dp_range_for_worker(self.engine_client.vllm_config)
# Initialize InputParamManager for text-in-text-out mode # Initialize InputParamManager for text-in-text-out mode
tokenizer = None tokenizer = None
if use_vllm_tokenizer and hasattr(engine, "tokenizer"): if use_vllm_tokenizer and hasattr(engine, "tokenizer"):
...@@ -463,6 +492,21 @@ class BaseWorkerHandler(ABC): ...@@ -463,6 +492,21 @@ class BaseWorkerHandler(ABC):
if temp_dir is not None: if temp_dir is not None:
self.temp_dirs.append(temp_dir) self.temp_dirs.append(temp_dir)
def _to_local_dp_rank(self, dp_rank: int | None) -> int | None:
"""Convert global DP rank to local DP rank based on engine config."""
if dp_rank is None:
return None
if dp_rank < self.dp_range[0] or dp_rank >= self.dp_range[0] + self.dp_range[1]:
logger.warning(
f"Received DP rank {dp_rank} is out of range [{self.dp_range[0]} - {self.dp_range[0] + self.dp_range[1]}), fallback to vLLM internal DP selection"
)
return None
local_dp_rank = (dp_rank - self.dp_range[0]) % self.dp_range[1]
logger.debug(
f"Converted global DP rank {dp_rank} to local DP rank {local_dp_rank}"
)
return local_dp_rank
def _resolve_lora_request(self, model_name: str | None) -> LoRARequest | None: def _resolve_lora_request(self, model_name: str | None) -> LoRARequest | None:
"""Return a LoRARequest if model_name is a loaded adapter, else None.""" """Return a LoRARequest if model_name is a loaded adapter, else None."""
if model_name and (lora := self.loaded_loras.get(model_name)): if model_name and (lora := self.loaded_loras.get(model_name)):
...@@ -1317,7 +1361,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1317,7 +1361,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
f"Decode request {request_id} has no LoRA specified (model: {model_name})" f"Decode request {request_id} has no LoRA specified (model: {model_name})"
) )
routing = request.get("routing") or {} routing = request.get("routing") or {}
dp_rank = routing.get("dp_rank") dp_rank = self._to_local_dp_rank(routing.get("dp_rank"))
priority = routing.get("priority", 0) priority = routing.get("priority", 0)
trace_headers = build_trace_headers(context) trace_headers = build_trace_headers(context)
...@@ -1364,7 +1408,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1364,7 +1408,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
) )
routing = request.get("routing") or {} routing = request.get("routing") or {}
dp_rank = routing.get("dp_rank") dp_rank = self._to_local_dp_rank(routing.get("dp_rank"))
priority = routing.get("priority", 0) priority = routing.get("priority", 0)
openai_request_id = request.get("id") or request.get("request_id", request_id) openai_request_id = request.get("id") or request.get("request_id", request_id)
previous_text = "" previous_text = ""
...@@ -1525,7 +1569,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1525,7 +1569,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
) )
routing = request.get("routing") or {} routing = request.get("routing") or {}
dp_rank = routing.get("dp_rank") dp_rank = self._to_local_dp_rank(routing.get("dp_rank"))
priority = routing.get("priority", 0) priority = routing.get("priority", 0)
trace_headers = build_trace_headers(context) trace_headers = build_trace_headers(context)
......
...@@ -55,7 +55,7 @@ from dynamo.vllm.worker_factory import WorkerFactory ...@@ -55,7 +55,7 @@ from dynamo.vllm.worker_factory import WorkerFactory
from .args import Config, _uses_dynamo_connector, parse_args from .args import Config, _uses_dynamo_connector, parse_args
from .checkpoint_restore import get_checkpoint_config from .checkpoint_restore import get_checkpoint_config
from .constants import DisaggregationMode from .constants import DisaggregationMode
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler from .handlers import DecodeWorkerHandler, PrefillWorkerHandler, get_dp_range_for_worker
from .health_check import ( from .health_check import (
VllmHealthCheckPayload, VllmHealthCheckPayload,
VllmOmniHealthCheckPayload, VllmOmniHealthCheckPayload,
...@@ -69,19 +69,6 @@ shutdown_endpoints: list = [] ...@@ -69,19 +69,6 @@ shutdown_endpoints: list = []
CHECKPOINT_SLEEP_MODE_LEVEL = 1 CHECKPOINT_SLEEP_MODE_LEVEL = 1
async def _handle_non_leader_node(dp_rank: int) -> None:
"""
Handle non-leader node (data_parallel_rank >= 1) in multi-node deployments.
Non-leader nodes run vLLM workers but don't serve Dynamo endpoints.
"""
logger.info(
f"Non-leader node detected (data_parallel_rank={dp_rank}). "
"Skipping endpoint serving."
)
# Wait indefinitely - process terminated via signal handlers
await asyncio.Event().wait()
def build_headless_namespace(config: Config) -> argparse.Namespace: def build_headless_namespace(config: Config) -> argparse.Namespace:
"""Build an argparse Namespace from engine_args for vLLM's run_headless(). """Build an argparse Namespace from engine_args for vLLM's run_headless().
...@@ -339,11 +326,12 @@ def setup_kv_event_publisher( ...@@ -339,11 +326,12 @@ def setup_kv_event_publisher(
) )
return None return None
# Get data_parallel_size to create publishers for all dp_ranks # Get DP rank range managed by this worker to create publishers for corresponding dp_ranks,
data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1) # all served workers should cover all ranks.
dp_start, dp_size = get_dp_range_for_worker(vllm_config)
kv_publishers = [] kv_publishers = []
for dp_rank in range(data_parallel_size): for dp_rank in range(dp_start, dp_start + dp_size):
if consolidator_enabled: if consolidator_enabled:
# TODO: Use different port for each dp_rank once KVBM supports DP # TODO: Use different port for each dp_rank once KVBM supports DP
zmq_endpoint = f"tcp://127.0.0.1:{consolidator_port}" zmq_endpoint = f"tcp://127.0.0.1:{consolidator_port}"
...@@ -561,8 +549,9 @@ async def register_vllm_model( ...@@ -561,8 +549,9 @@ async def register_vllm_model(
runtime_config.reasoning_parser = config.dyn_reasoning_parser runtime_config.reasoning_parser = config.dyn_reasoning_parser
# Get data_parallel_size from vllm_config (defaults to 1) # Get data_parallel_size from vllm_config (defaults to 1)
data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1) dp_range = get_dp_range_for_worker(vllm_config)
runtime_config.data_parallel_size = data_parallel_size runtime_config.data_parallel_start_rank = dp_range[0]
runtime_config.data_parallel_size = dp_range[1]
# Configure media decoder for frontend image decoding when enabled # Configure media decoder for frontend image decoding when enabled
# This enables frontend to decode images and transfer via NIXL RDMA # This enables frontend to decode images and transfer via NIXL RDMA
...@@ -675,10 +664,6 @@ async def init_prefill( ...@@ -675,10 +664,6 @@ async def init_prefill(
runtime.register_engine_route("wake_up", handler.wake_up) runtime.register_engine_route("wake_up", handler.wake_up)
logger.info("Registered engine routes: /engine/sleep, /engine/wake_up") logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")
# Handle non-leader nodes - don't serve endpoints
if config.engine_args.data_parallel_rank:
await _handle_non_leader_node(config.engine_args.data_parallel_rank)
return
shutdown_endpoints[:] = [generate_endpoint, clear_endpoint] shutdown_endpoints[:] = [generate_endpoint, clear_endpoint]
# Register prefill model with ModelType.Prefill # Register prefill model with ModelType.Prefill
...@@ -791,7 +776,6 @@ async def init( ...@@ -791,7 +776,6 @@ async def init(
factory = StatLoggerFactory( factory = StatLoggerFactory(
endpoint=generate_endpoint, endpoint=generate_endpoint,
component_gauges=component_gauges, component_gauges=component_gauges,
dp_rank=config.engine_args.data_parallel_rank or 0,
) )
else: else:
# Factory is created without component_gauges; setup_vllm_engine() will # Factory is created without component_gauges; setup_vllm_engine() will
...@@ -799,7 +783,6 @@ async def init( ...@@ -799,7 +783,6 @@ async def init(
# on the factory before vLLM calls create_stat_logger(). # on the factory before vLLM calls create_stat_logger().
factory = StatLoggerFactory( factory = StatLoggerFactory(
endpoint=generate_endpoint, endpoint=generate_endpoint,
dp_rank=config.engine_args.data_parallel_rank or 0,
) )
( (
engine_client, engine_client,
...@@ -858,11 +841,6 @@ async def init( ...@@ -858,11 +841,6 @@ async def init(
runtime.register_engine_route("wake_up", handler.wake_up) runtime.register_engine_route("wake_up", handler.wake_up)
logger.info("Registered engine routes: /engine/sleep, /engine/wake_up") logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")
# Handle non-leader nodes - don't serve endpoints
if config.engine_args.data_parallel_rank:
await _handle_non_leader_node(config.engine_args.data_parallel_rank)
return
# Parse endpoint types from --endpoint-types flag # Parse endpoint types from --endpoint-types flag
model_type = parse_endpoint_types(config.endpoint_types) model_type = parse_endpoint_types(config.endpoint_types)
logger.info(f"Registering model with endpoint types: {config.endpoint_types}") logger.info(f"Registering model with endpoint types: {config.endpoint_types}")
...@@ -1011,11 +989,6 @@ async def init_omni( ...@@ -1011,11 +989,6 @@ async def init_omni(
# Set up metrics collection for vLLM and LMCache metrics # Set up metrics collection for vLLM and LMCache metrics
setup_metrics_collection(config, generate_endpoint, logger) setup_metrics_collection(config, generate_endpoint, logger)
# Handle non-leader nodes - don't serve endpoints
if config.engine_args.data_parallel_rank:
await _handle_non_leader_node(config.engine_args.data_parallel_rank)
return
# TODO: extend for multi-stage pipelines # TODO: extend for multi-stage pipelines
model_type = get_output_modalities(config.output_modalities, config.model) model_type = get_output_modalities(config.output_modalities, config.model)
if model_type is None: if model_type is None:
......
...@@ -19,24 +19,6 @@ from dynamo.runtime import Endpoint ...@@ -19,24 +19,6 @@ from dynamo.runtime import Endpoint
DYNAMO_COMPONENT_REGISTRY = CollectorRegistry() DYNAMO_COMPONENT_REGISTRY = CollectorRegistry()
class NullStatLogger(StatLoggerBase):
def __init__(self):
pass
def record(
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_idx: int = 0,
*args,
**kwargs,
):
pass
def log_engine_initialized(self):
pass
class DynamoStatLoggerPublisher(StatLoggerBase): class DynamoStatLoggerPublisher(StatLoggerBase):
"""Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface.""" """Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface."""
...@@ -106,22 +88,17 @@ class StatLoggerFactory: ...@@ -106,22 +88,17 @@ class StatLoggerFactory:
self, self,
endpoint: Endpoint, endpoint: Endpoint,
component_gauges: Optional[LLMBackendMetrics] = None, component_gauges: Optional[LLMBackendMetrics] = None,
dp_rank: int = 0,
) -> None: ) -> None:
self.endpoint = endpoint self.endpoint = endpoint
self.component_gauges = component_gauges self.component_gauges = component_gauges
self.created_logger: Optional[DynamoStatLoggerPublisher] = None self.created_logger: Optional[DynamoStatLoggerPublisher] = None
self.dp_rank = dp_rank
def create_stat_logger(self, dp_rank: int) -> StatLoggerBase: def create_stat_logger(self, dp_rank: int) -> StatLoggerBase:
if self.dp_rank != dp_rank:
return NullStatLogger()
# component_gauges must be set by setup_vllm_engine() before vLLM # component_gauges must be set by setup_vllm_engine() before vLLM
# calls create_stat_logger() during engine initialization. # calls create_stat_logger() during engine initialization.
assert ( assert (
self.component_gauges is not None self.component_gauges is not None
), "component_gauges must be set before creating stat loggers" ), "component_gauges must be set before creating stat loggers"
logger = DynamoStatLoggerPublisher( logger = DynamoStatLoggerPublisher(
endpoint=self.endpoint, endpoint=self.endpoint,
dp_rank=dp_rank, dp_rank=dp_rank,
......
...@@ -35,16 +35,16 @@ python -m dynamo.frontend --router-mode kv & ...@@ -35,16 +35,16 @@ python -m dynamo.frontend --router-mode kv &
# Routing to DP workers managed by Dynamo # Routing to DP workers managed by Dynamo
# Chose Qwen3-30B because its a small MOE that can fit on smaller GPUs (L40S for example) # Chose Qwen3-30B because its a small MOE that can fit on smaller GPUs (L40S for example)
# --enforce-eager is added for quick deployment. for production use, need to remove this flag # --enforce-eager is added for quick deployment. for production use, need to remove this flag
for i in {0..3}; do VLLM_NIXL_SIDE_CHANNEL_PORT=20096 \
VLLM_NIXL_SIDE_CHANNEL_PORT=$((20096 + i)) \ python3 -m dynamo.vllm \
CUDA_VISIBLE_DEVICES=$i python3 -m dynamo.vllm \ --model Qwen/Qwen3-30B-A3B \
--model "$MODEL" \ --data-parallel-hybrid-lb \
--data-parallel-rank $i \ --data-parallel-size 4 \
--data-parallel-size 4 \ --data-parallel-size-local 4 \
--enable-expert-parallel \ --data-parallel-start-rank 0 \
--enforce-eager \ --enable-expert-parallel \
--kv-events-config "{\"publisher\":\"zmq\",\"topic\":\"kv-events\",\"endpoint\":\"tcp://*:$((20080 + i))\",\"enable_kv_cache_events\":true}" & --enforce-eager \
done --kv-events-config "{\"publisher\":\"zmq\",\"topic\":\"kv-events\",\"endpoint\":\"tcp://*:20080\",\"enable_kv_cache_events\":true}" &
echo "All workers starting. (press Ctrl+C to stop)..." echo "All workers starting. (press Ctrl+C to stop)..."
wait wait
...@@ -116,25 +116,24 @@ mkdir -p $LOG_DIR ...@@ -116,25 +116,24 @@ mkdir -p $LOG_DIR
# the GPU memory requires for vLLM reservation and runtime spike (not # the GPU memory requires for vLLM reservation and runtime spike (not
# reserved by vLLM) can be different and cause model fails to start, # reserved by vLLM) can be different and cause model fails to start,
# adjust '--gpu-memory-utilization' as needed # adjust '--gpu-memory-utilization' as needed
for ((i=0; i<GPUS_PER_NODE; i++)); do dp_start_rank=$((NODE_RANK * GPUS_PER_NODE))
dp_rank=$((i + NODE_RANK * GPUS_PER_NODE)) VLLM_NIXL_SIDE_CHANNEL_PORT=20096 \
CUDA_VISIBLE_DEVICES=$i \ VLLM_ALL2ALL_BACKEND="deepep_low_latency" \
VLLM_NIXL_SIDE_CHANNEL_PORT=$((20096 + i)) \ VLLM_USE_DEEP_GEMM=1 \
VLLM_ALL2ALL_BACKEND="deepep_low_latency" \ VLLM_RANDOMIZE_DP_DUMMY_INPUTS=1 \
VLLM_USE_DEEP_GEMM=1 \ python3 -m dynamo.vllm \
VLLM_RANDOMIZE_DP_DUMMY_INPUTS=1 \ --model $MODEL \
python3 -m dynamo.vllm \ --data-parallel-hybrid-lb \
--model $MODEL \ --data-parallel-size $DATA_PARALLEL_SIZE \
--data_parallel_size $DATA_PARALLEL_SIZE \ --data-parallel-size-local $GPUS_PER_NODE \
--data-parallel-rank $dp_rank \ --data-parallel-start-rank $dp_start_rank \
--enable-expert-parallel \ --enable-expert-parallel \
--max-model-len 4096 \ --max-model-len 4096 \
--data-parallel-address $MASTER_ADDR \ --data-parallel-address $MASTER_ADDR \
--data-parallel-rpc-port 13345 \ --data-parallel-rpc-port 13345 \
--gpu-memory-utilization 0.91 \ --gpu-memory-utilization 0.91 \
--enforce-eager \ --enforce-eager \
--kv-events-config "{\"publisher\":\"zmq\",\"topic\":\"kv-events\",\"endpoint\":\"tcp://*:$((20080 + i))\",\"enable_kv_cache_events\":true}" 2>&1 | tee $LOG_DIR/dsr1_dep_${dp_rank}.log & --kv-events-config "{\"publisher\":\"zmq\",\"topic\":\"kv-events\",\"endpoint\":\"tcp://*:20080\",\"enable_kv_cache_events\":true}" 2>&1 | tee $LOG_DIR/dsr1_dep_${dp_start_rank}.log &
done
echo "All workers starting. (press Ctrl+C to stop)..." echo "All workers starting. (press Ctrl+C to stop)..."
wait wait
...@@ -45,6 +45,11 @@ impl ModelRuntimeConfig { ...@@ -45,6 +45,11 @@ impl ModelRuntimeConfig {
self.inner.reasoning_parser = reasoning_parser; self.inner.reasoning_parser = reasoning_parser;
} }
#[setter]
fn set_data_parallel_start_rank(&mut self, data_parallel_start_rank: u32) {
self.inner.data_parallel_start_rank = data_parallel_start_rank;
}
#[setter] #[setter]
fn set_data_parallel_size(&mut self, data_parallel_size: u32) { fn set_data_parallel_size(&mut self, data_parallel_size: u32) {
self.inner.data_parallel_size = data_parallel_size; self.inner.data_parallel_size = data_parallel_size;
......
...@@ -289,11 +289,12 @@ async fn run_benchmark( ...@@ -289,11 +289,12 @@ async fn run_benchmark(
// Total bench workers = trace workers × duplication factor. // Total bench workers = trace workers × duplication factor.
// Each gets a unique WorkerWithDpRank in the shared multi-worker. // Each gets a unique WorkerWithDpRank in the shared multi-worker.
let total_workers = num_trace_workers * inference_worker_duplication_factor; let total_workers = num_trace_workers * inference_worker_duplication_factor;
let dp_sizes: HashMap<u64, u32> = (0..total_workers as u64).map(|id| (id, 1)).collect(); let dp_range: HashMap<u64, (u32, u32)> =
(0..total_workers as u64).map(|id| (id, (0, 1))).collect();
let multi = Arc::new(ActiveSequencesMultiWorker::new( let multi = Arc::new(ActiveSequencesMultiWorker::new(
NoopSequencePublisher, NoopSequencePublisher,
block_size as usize, block_size as usize,
dp_sizes, dp_range,
false, false,
0, 0,
"bench", "bench",
......
...@@ -96,6 +96,7 @@ pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<Sequen ...@@ -96,6 +96,7 @@ pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<Sequen
/// ///
/// `ModelRuntimeConfig` (in `lib/llm`) implements this directly so no adapter type is needed. /// `ModelRuntimeConfig` (in `lib/llm`) implements this directly so no adapter type is needed.
pub trait WorkerConfigLike { pub trait WorkerConfigLike {
fn data_parallel_start_rank(&self) -> u32;
fn data_parallel_size(&self) -> u32; fn data_parallel_size(&self) -> u32;
fn max_num_batched_tokens(&self) -> Option<u64>; fn max_num_batched_tokens(&self) -> Option<u64>;
fn total_kv_blocks(&self) -> Option<u64>; fn total_kv_blocks(&self) -> Option<u64>;
......
...@@ -209,11 +209,12 @@ impl<P: SequencePublisher + 'static, C: WorkerConfigLike> SchedulerQueue<P, C> { ...@@ -209,11 +209,12 @@ impl<P: SequencePublisher + 'static, C: WorkerConfigLike> SchedulerQueue<P, C> {
for (&worker_id, config) in configs.iter() { for (&worker_id, config) in configs.iter() {
let dp_size = config.data_parallel_size(); let dp_size = config.data_parallel_size();
let dp_start_rank = config.data_parallel_start_rank();
let max_batched = config let max_batched = config
.max_num_batched_tokens() .max_num_batched_tokens()
.unwrap_or(DEFAULT_MAX_BATCHED_TOKENS); .unwrap_or(DEFAULT_MAX_BATCHED_TOKENS);
for dp_rank in 0..dp_size { for dp_rank in dp_start_rank..dp_start_rank + dp_size {
let worker = WorkerWithDpRank::new(worker_id, dp_rank); let worker = WorkerWithDpRank::new(worker_id, dp_rank);
let tokens = active_tokens.get(&worker).copied().unwrap_or(0); let tokens = active_tokens.get(&worker).copied().unwrap_or(0);
if (tokens as f64) <= threshold * (max_batched as f64) { if (tokens as f64) <= threshold * (max_batched as f64) {
...@@ -247,11 +248,12 @@ mod tests { ...@@ -247,11 +248,12 @@ mod tests {
Arc<SchedulerQueue<NoopSequencePublisher, SimpleWorkerConfig>>, Arc<SchedulerQueue<NoopSequencePublisher, SimpleWorkerConfig>>,
Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>, Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>,
) { ) {
let dp_sizes: HashMap<u64, u32> = (0..num_workers as u64).map(|id| (id, 1)).collect(); let dp_range: HashMap<u64, (u32, u32)> =
(0..num_workers as u64).map(|id| (id, (0, 1))).collect();
let slots = Arc::new(ActiveSequencesMultiWorker::new( let slots = Arc::new(ActiveSequencesMultiWorker::new(
NoopSequencePublisher, NoopSequencePublisher,
block_size as usize, block_size as usize,
dp_sizes, dp_range,
false, false,
0, 0,
"test", "test",
......
...@@ -129,8 +129,10 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector { ...@@ -129,8 +129,10 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
.filter(|(wid, _)| allowed_ids.is_none_or(|ids| ids.contains(wid))) .filter(|(wid, _)| allowed_ids.is_none_or(|ids| ids.contains(wid)))
{ {
let data_parallel_size = config.data_parallel_size(); let data_parallel_size = config.data_parallel_size();
let data_parallel_start_rank = config.data_parallel_start_rank();
for dp_rank in 0..data_parallel_size { for dp_rank in data_parallel_start_rank..(data_parallel_start_rank + data_parallel_size)
{
let worker = WorkerWithDpRank::new(*worker_id, dp_rank); let worker = WorkerWithDpRank::new(*worker_id, dp_rank);
let overlap = *overlaps.get(&worker).unwrap_or(&0); let overlap = *overlaps.get(&worker).unwrap_or(&0);
......
...@@ -103,11 +103,11 @@ struct WorkerTable { ...@@ -103,11 +103,11 @@ struct WorkerTable {
} }
impl WorkerTable { impl WorkerTable {
fn new(block_size: usize, dp_sizes: &HashMap<u64, u32>) -> Self { fn new(block_size: usize, dp_range: &HashMap<u64, (u32, u32)>) -> Self {
let mut slots = Vec::new(); let mut slots = Vec::new();
let mut index = HashMap::new(); let mut index = HashMap::new();
for (&worker_id, &dp_size) in dp_sizes { for (&worker_id, &(dp_start, dp_size)) in dp_range {
for dp_rank in 0..dp_size { for dp_rank in dp_start..dp_start + dp_size {
let worker = WorkerWithDpRank::new(worker_id, dp_rank); let worker = WorkerWithDpRank::new(worker_id, dp_rank);
let idx = slots.len(); let idx = slots.len();
slots.push((worker, RwLock::new(ActiveSequences::new(block_size)))); slots.push((worker, RwLock::new(ActiveSequences::new(block_size))));
...@@ -149,7 +149,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -149,7 +149,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
pub fn new( pub fn new(
publisher: P, publisher: P,
block_size: usize, block_size: usize,
dp_sizes: HashMap<u64, u32>, dp_range: HashMap<u64, (u32, u32)>,
replica_sync: bool, replica_sync: bool,
router_id: u64, router_id: u64,
worker_type: &'static str, worker_type: &'static str,
...@@ -157,7 +157,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -157,7 +157,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
assert!(block_size > 1, "block_size must be greater than 1"); assert!(block_size > 1, "block_size must be greater than 1");
Self { Self {
workers: RwLock::new(WorkerTable::new(block_size, &dp_sizes)), workers: RwLock::new(WorkerTable::new(block_size, &dp_range)),
request_to_worker: DashMap::new(), request_to_worker: DashMap::new(),
request_to_lora: DashMap::new(), request_to_lora: DashMap::new(),
block_size, block_size,
...@@ -276,13 +276,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -276,13 +276,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// Update the set of workers, adding and removing as needed. /// Update the set of workers, adding and removing as needed.
/// ///
/// `new_dp_sizes` maps worker IDs to their data-parallel size. /// `new_dp_range` maps worker IDs to their data-parallel range (start, size).
pub fn update_workers(&self, new_dp_sizes: &HashMap<u64, u32>) { pub fn update_workers(&self, new_dp_range: &HashMap<u64, (u32, u32)>) {
let mut table = self.workers.write(); let mut table = self.workers.write();
let mut target_workers: HashSet<WorkerWithDpRank> = HashSet::new(); let mut target_workers: HashSet<WorkerWithDpRank> = HashSet::new();
for (&worker_id, &dp_size) in new_dp_sizes { for (&worker_id, &(dp_start, dp_size)) in new_dp_range {
for dp_rank in 0..dp_size { for dp_rank in dp_start..(dp_start + dp_size) {
target_workers.insert(WorkerWithDpRank::new(worker_id, dp_rank)); target_workers.insert(WorkerWithDpRank::new(worker_id, dp_rank));
} }
} }
......
...@@ -85,6 +85,7 @@ impl SequencePublisher for NoopSequencePublisher { ...@@ -85,6 +85,7 @@ impl SequencePublisher for NoopSequencePublisher {
/// Minimal [`WorkerConfigLike`] for scheduler/queue tests and benchmarks. /// Minimal [`WorkerConfigLike`] for scheduler/queue tests and benchmarks.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct SimpleWorkerConfig { pub struct SimpleWorkerConfig {
pub data_parallel_start_rank: u32,
pub data_parallel_size: u32, pub data_parallel_size: u32,
pub max_num_batched_tokens: Option<u64>, pub max_num_batched_tokens: Option<u64>,
pub total_kv_blocks: Option<u64>, pub total_kv_blocks: Option<u64>,
...@@ -93,6 +94,7 @@ pub struct SimpleWorkerConfig { ...@@ -93,6 +94,7 @@ pub struct SimpleWorkerConfig {
impl Default for SimpleWorkerConfig { impl Default for SimpleWorkerConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
data_parallel_start_rank: 0,
data_parallel_size: 1, data_parallel_size: 1,
max_num_batched_tokens: None, max_num_batched_tokens: None,
total_kv_blocks: None, total_kv_blocks: None,
...@@ -101,6 +103,10 @@ impl Default for SimpleWorkerConfig { ...@@ -101,6 +103,10 @@ impl Default for SimpleWorkerConfig {
} }
impl WorkerConfigLike for SimpleWorkerConfig { impl WorkerConfigLike for SimpleWorkerConfig {
fn data_parallel_start_rank(&self) -> u32 {
self.data_parallel_start_rank
}
fn data_parallel_size(&self) -> u32 { fn data_parallel_size(&self) -> u32 {
self.data_parallel_size self.data_parallel_size
} }
......
...@@ -456,22 +456,25 @@ impl WorkerLoadMonitor for KvWorkerMonitor { ...@@ -456,22 +456,25 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
for (lease_id, runtime_config) in runtime_configs.iter() { for (lease_id, runtime_config) in runtime_configs.iter() {
let mut state = worker_load_states.entry(*lease_id).or_default(); let mut state = worker_load_states.entry(*lease_id).or_default();
let dp_start = runtime_config.data_parallel_start_rank;
let dp_end = dp_start + runtime_config.data_parallel_size;
// Track dp_ranks for this worker (for cleanup when worker disappears) // Track dp_ranks for this worker (for cleanup when worker disappears)
let dp_ranks_set = known_worker_dp_ranks.entry(*lease_id).or_default(); let dp_ranks_set = known_worker_dp_ranks.entry(*lease_id).or_default();
for dp_rank in 0..runtime_config.data_parallel_size { for dp_rank in dp_start..dp_end {
dp_ranks_set.insert(dp_rank); dp_ranks_set.insert(dp_rank);
} }
// Populate total_blocks for all dp_ranks (they share the same total) // Populate total_blocks for all dp_ranks (they share the same total)
if let Some(total_blocks) = runtime_config.total_kv_blocks { if let Some(total_blocks) = runtime_config.total_kv_blocks {
for dp_rank in 0..runtime_config.data_parallel_size { for dp_rank in dp_start..dp_end {
state.kv_total_blocks.insert(dp_rank, total_blocks); state.kv_total_blocks.insert(dp_rank, total_blocks);
} }
} }
// Populate max_num_batched_tokens for all dp_ranks // Populate max_num_batched_tokens for all dp_ranks
if let Some(max_batched) = runtime_config.max_num_batched_tokens { if let Some(max_batched) = runtime_config.max_num_batched_tokens {
for dp_rank in 0..runtime_config.data_parallel_size { for dp_rank in dp_start..dp_end {
state.max_num_batched_tokens.insert(dp_rank, max_batched); state.max_num_batched_tokens.insert(dp_rank, max_batched);
} }
} }
......
...@@ -87,11 +87,11 @@ impl KvScheduler { ...@@ -87,11 +87,11 @@ impl KvScheduler {
let current_workers = monitor_rx.borrow_and_update().clone(); let current_workers = monitor_rx.borrow_and_update().clone();
if current_workers != last_workers { if current_workers != last_workers {
let dp_sizes: HashMap<u64, u32> = current_workers let dp_range: HashMap<u64, (u32, u32)> = current_workers
.iter() .iter()
.map(|(&id, c)| (id, c.data_parallel_size)) .map(|(&id, c)| (id, (c.data_parallel_start_rank, c.data_parallel_size)))
.collect(); .collect();
slots_monitor.update_workers(&dp_sizes); slots_monitor.update_workers(&dp_range);
last_workers = current_workers; last_workers = current_workers;
} }
} }
......
...@@ -103,15 +103,20 @@ pub async fn create_multi_worker_sequences( ...@@ -103,15 +103,20 @@ pub async fn create_multi_worker_sequences(
metrics_publisher, metrics_publisher,
}; };
let dp_sizes: HashMap<u64, u32> = workers_with_configs let dp_range: HashMap<u64, (u32, u32)> = workers_with_configs
.into_iter() .into_iter()
.map(|(id, config)| (id, config.data_parallel_size)) .map(|(id, config)| {
(
id,
(config.data_parallel_start_rank, config.data_parallel_size),
)
})
.collect(); .collect();
let multi_worker = ActiveSequencesMultiWorker::new( let multi_worker = ActiveSequencesMultiWorker::new(
publisher, publisher,
block_size, block_size,
dp_sizes, dp_range,
replica_sync, replica_sync,
router_id, router_id,
worker_type, worker_type,
......
...@@ -28,6 +28,10 @@ pub struct ModelRuntimeConfig { ...@@ -28,6 +28,10 @@ pub struct ModelRuntimeConfig {
pub reasoning_parser: Option<String>, pub reasoning_parser: Option<String>,
/// Starting rank of data parallel ranks for this worker (0 if DP not enabled)
#[serde(default = "default_data_parallel_start_rank")]
pub data_parallel_start_rank: u32,
/// Total number of data parallel ranks for this worker (1 if DP not enabled) /// Total number of data parallel ranks for this worker (1 if DP not enabled)
#[serde(default = "default_data_parallel_size")] #[serde(default = "default_data_parallel_size")]
pub data_parallel_size: u32, pub data_parallel_size: u32,
...@@ -55,6 +59,10 @@ pub struct ModelRuntimeConfig { ...@@ -55,6 +59,10 @@ pub struct ModelRuntimeConfig {
pub disaggregated_endpoint: Option<DisaggregatedEndpoint>, pub disaggregated_endpoint: Option<DisaggregatedEndpoint>,
} }
const fn default_data_parallel_start_rank() -> u32 {
0
}
const fn default_data_parallel_size() -> u32 { const fn default_data_parallel_size() -> u32 {
1 1
} }
...@@ -71,6 +79,7 @@ impl Default for ModelRuntimeConfig { ...@@ -71,6 +79,7 @@ impl Default for ModelRuntimeConfig {
max_num_batched_tokens: None, max_num_batched_tokens: None,
tool_call_parser: None, tool_call_parser: None,
reasoning_parser: None, reasoning_parser: None,
data_parallel_start_rank: default_data_parallel_start_rank(),
data_parallel_size: default_data_parallel_size(), data_parallel_size: default_data_parallel_size(),
enable_local_indexer: true, enable_local_indexer: true,
runtime_data: HashMap::new(), runtime_data: HashMap::new(),
...@@ -81,6 +90,10 @@ impl Default for ModelRuntimeConfig { ...@@ -81,6 +90,10 @@ impl Default for ModelRuntimeConfig {
} }
impl dynamo_kv_router::WorkerConfigLike for ModelRuntimeConfig { impl dynamo_kv_router::WorkerConfigLike for ModelRuntimeConfig {
fn data_parallel_start_rank(&self) -> u32 {
self.data_parallel_start_rank
}
fn data_parallel_size(&self) -> u32 { fn data_parallel_size(&self) -> u32 {
self.data_parallel_size self.data_parallel_size
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment