Unverified Commit 90d74637 authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

fix: properly setup and register vLLM worker for external / hybrid load...


fix: properly setup and register vLLM worker for external / hybrid load balancing. Update launch script (#6695)
Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
parent 9254e3d4
......@@ -16,6 +16,7 @@ from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Final
import torch
from vllm.config import VllmConfig
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
......@@ -260,6 +261,32 @@ def build_sampling_params_openai(
return sampling_params
def get_dp_range_for_worker(vllm_config: VllmConfig) -> range:
"""
Get the global DP rank range that this worker is responsible for based on vLLM config.
Note that the 'vllm_config' is normalized so the load balancing flags are set properly.
The return value is in the format of (start_dp_rank, managed_dp_size)."""
if vllm_config.parallel_config.data_parallel_external_lb:
# external load balancing, each worker is responsible for exactly 1 rank
return (vllm_config.parallel_config.data_parallel_rank, 1)
elif vllm_config.parallel_config.data_parallel_hybrid_lb:
# hybrid load balancing, each worker is responsible for a subset of local ranks
return (
vllm_config.parallel_config.data_parallel_rank,
vllm_config.parallel_config.data_parallel_size_local,
)
else:
# internal load balancing, the worker is responsible for all DP ranks
logger.warning(
"vLLM selects internal DP load balancing. If you are launching multiple workers for DP deployment,"
" hybrid or external load balancing is recommended."
)
return (
vllm_config.parallel_config.data_parallel_rank,
vllm_config.parallel_config.data_parallel_size,
)
class BaseWorkerHandler(ABC):
"""
Request handler for the generate and clear_kv_blocks endpoints.
......@@ -302,6 +329,8 @@ class BaseWorkerHandler(ABC):
self.use_vllm_tokenizer = use_vllm_tokenizer
self.dp_range = get_dp_range_for_worker(self.engine_client.vllm_config)
# Initialize InputParamManager for text-in-text-out mode
tokenizer = None
if use_vllm_tokenizer and hasattr(engine, "tokenizer"):
......@@ -463,6 +492,21 @@ class BaseWorkerHandler(ABC):
if temp_dir is not None:
self.temp_dirs.append(temp_dir)
def _to_local_dp_rank(self, dp_rank: int | None) -> int | None:
"""Convert global DP rank to local DP rank based on engine config."""
if dp_rank is None:
return None
if dp_rank < self.dp_range[0] or dp_rank >= self.dp_range[0] + self.dp_range[1]:
logger.warning(
f"Received DP rank {dp_rank} is out of range [{self.dp_range[0]} - {self.dp_range[0] + self.dp_range[1]}), fallback to vLLM internal DP selection"
)
return None
local_dp_rank = (dp_rank - self.dp_range[0]) % self.dp_range[1]
logger.debug(
f"Converted global DP rank {dp_rank} to local DP rank {local_dp_rank}"
)
return local_dp_rank
def _resolve_lora_request(self, model_name: str | None) -> LoRARequest | None:
"""Return a LoRARequest if model_name is a loaded adapter, else None."""
if model_name and (lora := self.loaded_loras.get(model_name)):
......@@ -1317,7 +1361,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
f"Decode request {request_id} has no LoRA specified (model: {model_name})"
)
routing = request.get("routing") or {}
dp_rank = routing.get("dp_rank")
dp_rank = self._to_local_dp_rank(routing.get("dp_rank"))
priority = routing.get("priority", 0)
trace_headers = build_trace_headers(context)
......@@ -1364,7 +1408,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
)
routing = request.get("routing") or {}
dp_rank = routing.get("dp_rank")
dp_rank = self._to_local_dp_rank(routing.get("dp_rank"))
priority = routing.get("priority", 0)
openai_request_id = request.get("id") or request.get("request_id", request_id)
previous_text = ""
......@@ -1525,7 +1569,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
)
routing = request.get("routing") or {}
dp_rank = routing.get("dp_rank")
dp_rank = self._to_local_dp_rank(routing.get("dp_rank"))
priority = routing.get("priority", 0)
trace_headers = build_trace_headers(context)
......
......@@ -55,7 +55,7 @@ from dynamo.vllm.worker_factory import WorkerFactory
from .args import Config, _uses_dynamo_connector, parse_args
from .checkpoint_restore import get_checkpoint_config
from .constants import DisaggregationMode
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler, get_dp_range_for_worker
from .health_check import (
VllmHealthCheckPayload,
VllmOmniHealthCheckPayload,
......@@ -69,19 +69,6 @@ shutdown_endpoints: list = []
CHECKPOINT_SLEEP_MODE_LEVEL = 1
async def _handle_non_leader_node(dp_rank: int) -> None:
"""
Handle non-leader node (data_parallel_rank >= 1) in multi-node deployments.
Non-leader nodes run vLLM workers but don't serve Dynamo endpoints.
"""
logger.info(
f"Non-leader node detected (data_parallel_rank={dp_rank}). "
"Skipping endpoint serving."
)
# Wait indefinitely - process terminated via signal handlers
await asyncio.Event().wait()
def build_headless_namespace(config: Config) -> argparse.Namespace:
"""Build an argparse Namespace from engine_args for vLLM's run_headless().
......@@ -339,11 +326,12 @@ def setup_kv_event_publisher(
)
return None
# Get data_parallel_size to create publishers for all dp_ranks
data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
# Get DP rank range managed by this worker to create publishers for corresponding dp_ranks,
# all served workers should cover all ranks.
dp_start, dp_size = get_dp_range_for_worker(vllm_config)
kv_publishers = []
for dp_rank in range(data_parallel_size):
for dp_rank in range(dp_start, dp_start + dp_size):
if consolidator_enabled:
# TODO: Use different port for each dp_rank once KVBM supports DP
zmq_endpoint = f"tcp://127.0.0.1:{consolidator_port}"
......@@ -561,8 +549,9 @@ async def register_vllm_model(
runtime_config.reasoning_parser = config.dyn_reasoning_parser
# Get data_parallel_size from vllm_config (defaults to 1)
data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
runtime_config.data_parallel_size = data_parallel_size
dp_range = get_dp_range_for_worker(vllm_config)
runtime_config.data_parallel_start_rank = dp_range[0]
runtime_config.data_parallel_size = dp_range[1]
# Configure media decoder for frontend image decoding when enabled
# This enables frontend to decode images and transfer via NIXL RDMA
......@@ -675,10 +664,6 @@ async def init_prefill(
runtime.register_engine_route("wake_up", handler.wake_up)
logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")
# Handle non-leader nodes - don't serve endpoints
if config.engine_args.data_parallel_rank:
await _handle_non_leader_node(config.engine_args.data_parallel_rank)
return
shutdown_endpoints[:] = [generate_endpoint, clear_endpoint]
# Register prefill model with ModelType.Prefill
......@@ -791,7 +776,6 @@ async def init(
factory = StatLoggerFactory(
endpoint=generate_endpoint,
component_gauges=component_gauges,
dp_rank=config.engine_args.data_parallel_rank or 0,
)
else:
# Factory is created without component_gauges; setup_vllm_engine() will
......@@ -799,7 +783,6 @@ async def init(
# on the factory before vLLM calls create_stat_logger().
factory = StatLoggerFactory(
endpoint=generate_endpoint,
dp_rank=config.engine_args.data_parallel_rank or 0,
)
(
engine_client,
......@@ -858,11 +841,6 @@ async def init(
runtime.register_engine_route("wake_up", handler.wake_up)
logger.info("Registered engine routes: /engine/sleep, /engine/wake_up")
# Handle non-leader nodes - don't serve endpoints
if config.engine_args.data_parallel_rank:
await _handle_non_leader_node(config.engine_args.data_parallel_rank)
return
# Parse endpoint types from --endpoint-types flag
model_type = parse_endpoint_types(config.endpoint_types)
logger.info(f"Registering model with endpoint types: {config.endpoint_types}")
......@@ -1011,11 +989,6 @@ async def init_omni(
# Set up metrics collection for vLLM and LMCache metrics
setup_metrics_collection(config, generate_endpoint, logger)
# Handle non-leader nodes - don't serve endpoints
if config.engine_args.data_parallel_rank:
await _handle_non_leader_node(config.engine_args.data_parallel_rank)
return
# TODO: extend for multi-stage pipelines
model_type = get_output_modalities(config.output_modalities, config.model)
if model_type is None:
......
......@@ -19,24 +19,6 @@ from dynamo.runtime import Endpoint
DYNAMO_COMPONENT_REGISTRY = CollectorRegistry()
class NullStatLogger(StatLoggerBase):
def __init__(self):
pass
def record(
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_idx: int = 0,
*args,
**kwargs,
):
pass
def log_engine_initialized(self):
pass
class DynamoStatLoggerPublisher(StatLoggerBase):
"""Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface."""
......@@ -106,22 +88,17 @@ class StatLoggerFactory:
self,
endpoint: Endpoint,
component_gauges: Optional[LLMBackendMetrics] = None,
dp_rank: int = 0,
) -> None:
self.endpoint = endpoint
self.component_gauges = component_gauges
self.created_logger: Optional[DynamoStatLoggerPublisher] = None
self.dp_rank = dp_rank
def create_stat_logger(self, dp_rank: int) -> StatLoggerBase:
if self.dp_rank != dp_rank:
return NullStatLogger()
# component_gauges must be set by setup_vllm_engine() before vLLM
# calls create_stat_logger() during engine initialization.
assert (
self.component_gauges is not None
), "component_gauges must be set before creating stat loggers"
logger = DynamoStatLoggerPublisher(
endpoint=self.endpoint,
dp_rank=dp_rank,
......
......@@ -35,16 +35,16 @@ python -m dynamo.frontend --router-mode kv &
# Routing to DP workers managed by Dynamo
# Chose Qwen3-30B because its a small MOE that can fit on smaller GPUs (L40S for example)
# --enforce-eager is added for quick deployment. for production use, need to remove this flag
for i in {0..3}; do
VLLM_NIXL_SIDE_CHANNEL_PORT=$((20096 + i)) \
CUDA_VISIBLE_DEVICES=$i python3 -m dynamo.vllm \
--model "$MODEL" \
--data-parallel-rank $i \
--data-parallel-size 4 \
--enable-expert-parallel \
--enforce-eager \
--kv-events-config "{\"publisher\":\"zmq\",\"topic\":\"kv-events\",\"endpoint\":\"tcp://*:$((20080 + i))\",\"enable_kv_cache_events\":true}" &
done
VLLM_NIXL_SIDE_CHANNEL_PORT=20096 \
python3 -m dynamo.vllm \
--model Qwen/Qwen3-30B-A3B \
--data-parallel-hybrid-lb \
--data-parallel-size 4 \
--data-parallel-size-local 4 \
--data-parallel-start-rank 0 \
--enable-expert-parallel \
--enforce-eager \
--kv-events-config "{\"publisher\":\"zmq\",\"topic\":\"kv-events\",\"endpoint\":\"tcp://*:20080\",\"enable_kv_cache_events\":true}" &
echo "All workers starting. (press Ctrl+C to stop)..."
wait
......@@ -116,25 +116,24 @@ mkdir -p $LOG_DIR
# the GPU memory requires for vLLM reservation and runtime spike (not
# reserved by vLLM) can be different and cause model fails to start,
# adjust '--gpu-memory-utilization' as needed
for ((i=0; i<GPUS_PER_NODE; i++)); do
dp_rank=$((i + NODE_RANK * GPUS_PER_NODE))
CUDA_VISIBLE_DEVICES=$i \
VLLM_NIXL_SIDE_CHANNEL_PORT=$((20096 + i)) \
VLLM_ALL2ALL_BACKEND="deepep_low_latency" \
VLLM_USE_DEEP_GEMM=1 \
VLLM_RANDOMIZE_DP_DUMMY_INPUTS=1 \
python3 -m dynamo.vllm \
--model $MODEL \
--data_parallel_size $DATA_PARALLEL_SIZE \
--data-parallel-rank $dp_rank \
--enable-expert-parallel \
--max-model-len 4096 \
--data-parallel-address $MASTER_ADDR \
--data-parallel-rpc-port 13345 \
--gpu-memory-utilization 0.91 \
--enforce-eager \
--kv-events-config "{\"publisher\":\"zmq\",\"topic\":\"kv-events\",\"endpoint\":\"tcp://*:$((20080 + i))\",\"enable_kv_cache_events\":true}" 2>&1 | tee $LOG_DIR/dsr1_dep_${dp_rank}.log &
done
dp_start_rank=$((NODE_RANK * GPUS_PER_NODE))
VLLM_NIXL_SIDE_CHANNEL_PORT=20096 \
VLLM_ALL2ALL_BACKEND="deepep_low_latency" \
VLLM_USE_DEEP_GEMM=1 \
VLLM_RANDOMIZE_DP_DUMMY_INPUTS=1 \
python3 -m dynamo.vllm \
--model $MODEL \
--data-parallel-hybrid-lb \
--data-parallel-size $DATA_PARALLEL_SIZE \
--data-parallel-size-local $GPUS_PER_NODE \
--data-parallel-start-rank $dp_start_rank \
--enable-expert-parallel \
--max-model-len 4096 \
--data-parallel-address $MASTER_ADDR \
--data-parallel-rpc-port 13345 \
--gpu-memory-utilization 0.91 \
--enforce-eager \
--kv-events-config "{\"publisher\":\"zmq\",\"topic\":\"kv-events\",\"endpoint\":\"tcp://*:20080\",\"enable_kv_cache_events\":true}" 2>&1 | tee $LOG_DIR/dsr1_dep_${dp_start_rank}.log &
echo "All workers starting. (press Ctrl+C to stop)..."
wait
......@@ -45,6 +45,11 @@ impl ModelRuntimeConfig {
self.inner.reasoning_parser = reasoning_parser;
}
#[setter]
fn set_data_parallel_start_rank(&mut self, data_parallel_start_rank: u32) {
self.inner.data_parallel_start_rank = data_parallel_start_rank;
}
#[setter]
fn set_data_parallel_size(&mut self, data_parallel_size: u32) {
self.inner.data_parallel_size = data_parallel_size;
......
......@@ -289,11 +289,12 @@ async fn run_benchmark(
// Total bench workers = trace workers × duplication factor.
// Each gets a unique WorkerWithDpRank in the shared multi-worker.
let total_workers = num_trace_workers * inference_worker_duplication_factor;
let dp_sizes: HashMap<u64, u32> = (0..total_workers as u64).map(|id| (id, 1)).collect();
let dp_range: HashMap<u64, (u32, u32)> =
(0..total_workers as u64).map(|id| (id, (0, 1))).collect();
let multi = Arc::new(ActiveSequencesMultiWorker::new(
NoopSequencePublisher,
block_size as usize,
dp_sizes,
dp_range,
false,
0,
"bench",
......
......@@ -96,6 +96,7 @@ pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<Sequen
///
/// `ModelRuntimeConfig` (in `lib/llm`) implements this directly so no adapter type is needed.
pub trait WorkerConfigLike {
fn data_parallel_start_rank(&self) -> u32;
fn data_parallel_size(&self) -> u32;
fn max_num_batched_tokens(&self) -> Option<u64>;
fn total_kv_blocks(&self) -> Option<u64>;
......
......@@ -209,11 +209,12 @@ impl<P: SequencePublisher + 'static, C: WorkerConfigLike> SchedulerQueue<P, C> {
for (&worker_id, config) in configs.iter() {
let dp_size = config.data_parallel_size();
let dp_start_rank = config.data_parallel_start_rank();
let max_batched = config
.max_num_batched_tokens()
.unwrap_or(DEFAULT_MAX_BATCHED_TOKENS);
for dp_rank in 0..dp_size {
for dp_rank in dp_start_rank..dp_start_rank + dp_size {
let worker = WorkerWithDpRank::new(worker_id, dp_rank);
let tokens = active_tokens.get(&worker).copied().unwrap_or(0);
if (tokens as f64) <= threshold * (max_batched as f64) {
......@@ -247,11 +248,12 @@ mod tests {
Arc<SchedulerQueue<NoopSequencePublisher, SimpleWorkerConfig>>,
Arc<ActiveSequencesMultiWorker<NoopSequencePublisher>>,
) {
let dp_sizes: HashMap<u64, u32> = (0..num_workers as u64).map(|id| (id, 1)).collect();
let dp_range: HashMap<u64, (u32, u32)> =
(0..num_workers as u64).map(|id| (id, (0, 1))).collect();
let slots = Arc::new(ActiveSequencesMultiWorker::new(
NoopSequencePublisher,
block_size as usize,
dp_sizes,
dp_range,
false,
0,
"test",
......
......@@ -129,8 +129,10 @@ impl<C: WorkerConfigLike> WorkerSelector<C> for DefaultWorkerSelector {
.filter(|(wid, _)| allowed_ids.is_none_or(|ids| ids.contains(wid)))
{
let data_parallel_size = config.data_parallel_size();
let data_parallel_start_rank = config.data_parallel_start_rank();
for dp_rank in 0..data_parallel_size {
for dp_rank in data_parallel_start_rank..(data_parallel_start_rank + data_parallel_size)
{
let worker = WorkerWithDpRank::new(*worker_id, dp_rank);
let overlap = *overlaps.get(&worker).unwrap_or(&0);
......
......@@ -103,11 +103,11 @@ struct WorkerTable {
}
impl WorkerTable {
fn new(block_size: usize, dp_sizes: &HashMap<u64, u32>) -> Self {
fn new(block_size: usize, dp_range: &HashMap<u64, (u32, u32)>) -> Self {
let mut slots = Vec::new();
let mut index = HashMap::new();
for (&worker_id, &dp_size) in dp_sizes {
for dp_rank in 0..dp_size {
for (&worker_id, &(dp_start, dp_size)) in dp_range {
for dp_rank in dp_start..dp_start + dp_size {
let worker = WorkerWithDpRank::new(worker_id, dp_rank);
let idx = slots.len();
slots.push((worker, RwLock::new(ActiveSequences::new(block_size))));
......@@ -149,7 +149,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
pub fn new(
publisher: P,
block_size: usize,
dp_sizes: HashMap<u64, u32>,
dp_range: HashMap<u64, (u32, u32)>,
replica_sync: bool,
router_id: u64,
worker_type: &'static str,
......@@ -157,7 +157,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
assert!(block_size > 1, "block_size must be greater than 1");
Self {
workers: RwLock::new(WorkerTable::new(block_size, &dp_sizes)),
workers: RwLock::new(WorkerTable::new(block_size, &dp_range)),
request_to_worker: DashMap::new(),
request_to_lora: DashMap::new(),
block_size,
......@@ -276,13 +276,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// Update the set of workers, adding and removing as needed.
///
/// `new_dp_sizes` maps worker IDs to their data-parallel size.
pub fn update_workers(&self, new_dp_sizes: &HashMap<u64, u32>) {
/// `new_dp_range` maps worker IDs to their data-parallel range (start, size).
pub fn update_workers(&self, new_dp_range: &HashMap<u64, (u32, u32)>) {
let mut table = self.workers.write();
let mut target_workers: HashSet<WorkerWithDpRank> = HashSet::new();
for (&worker_id, &dp_size) in new_dp_sizes {
for dp_rank in 0..dp_size {
for (&worker_id, &(dp_start, dp_size)) in new_dp_range {
for dp_rank in dp_start..(dp_start + dp_size) {
target_workers.insert(WorkerWithDpRank::new(worker_id, dp_rank));
}
}
......
......@@ -85,6 +85,7 @@ impl SequencePublisher for NoopSequencePublisher {
/// Minimal [`WorkerConfigLike`] for scheduler/queue tests and benchmarks.
#[derive(Debug, Clone)]
pub struct SimpleWorkerConfig {
pub data_parallel_start_rank: u32,
pub data_parallel_size: u32,
pub max_num_batched_tokens: Option<u64>,
pub total_kv_blocks: Option<u64>,
......@@ -93,6 +94,7 @@ pub struct SimpleWorkerConfig {
impl Default for SimpleWorkerConfig {
fn default() -> Self {
Self {
data_parallel_start_rank: 0,
data_parallel_size: 1,
max_num_batched_tokens: None,
total_kv_blocks: None,
......@@ -101,6 +103,10 @@ impl Default for SimpleWorkerConfig {
}
impl WorkerConfigLike for SimpleWorkerConfig {
fn data_parallel_start_rank(&self) -> u32 {
self.data_parallel_start_rank
}
fn data_parallel_size(&self) -> u32 {
self.data_parallel_size
}
......
......@@ -456,22 +456,25 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
for (lease_id, runtime_config) in runtime_configs.iter() {
let mut state = worker_load_states.entry(*lease_id).or_default();
let dp_start = runtime_config.data_parallel_start_rank;
let dp_end = dp_start + runtime_config.data_parallel_size;
// Track dp_ranks for this worker (for cleanup when worker disappears)
let dp_ranks_set = known_worker_dp_ranks.entry(*lease_id).or_default();
for dp_rank in 0..runtime_config.data_parallel_size {
for dp_rank in dp_start..dp_end {
dp_ranks_set.insert(dp_rank);
}
// Populate total_blocks for all dp_ranks (they share the same total)
if let Some(total_blocks) = runtime_config.total_kv_blocks {
for dp_rank in 0..runtime_config.data_parallel_size {
for dp_rank in dp_start..dp_end {
state.kv_total_blocks.insert(dp_rank, total_blocks);
}
}
// Populate max_num_batched_tokens for all dp_ranks
if let Some(max_batched) = runtime_config.max_num_batched_tokens {
for dp_rank in 0..runtime_config.data_parallel_size {
for dp_rank in dp_start..dp_end {
state.max_num_batched_tokens.insert(dp_rank, max_batched);
}
}
......
......@@ -87,11 +87,11 @@ impl KvScheduler {
let current_workers = monitor_rx.borrow_and_update().clone();
if current_workers != last_workers {
let dp_sizes: HashMap<u64, u32> = current_workers
let dp_range: HashMap<u64, (u32, u32)> = current_workers
.iter()
.map(|(&id, c)| (id, c.data_parallel_size))
.map(|(&id, c)| (id, (c.data_parallel_start_rank, c.data_parallel_size)))
.collect();
slots_monitor.update_workers(&dp_sizes);
slots_monitor.update_workers(&dp_range);
last_workers = current_workers;
}
}
......
......@@ -103,15 +103,20 @@ pub async fn create_multi_worker_sequences(
metrics_publisher,
};
let dp_sizes: HashMap<u64, u32> = workers_with_configs
let dp_range: HashMap<u64, (u32, u32)> = workers_with_configs
.into_iter()
.map(|(id, config)| (id, config.data_parallel_size))
.map(|(id, config)| {
(
id,
(config.data_parallel_start_rank, config.data_parallel_size),
)
})
.collect();
let multi_worker = ActiveSequencesMultiWorker::new(
publisher,
block_size,
dp_sizes,
dp_range,
replica_sync,
router_id,
worker_type,
......
......@@ -28,6 +28,10 @@ pub struct ModelRuntimeConfig {
pub reasoning_parser: Option<String>,
/// Starting rank of data parallel ranks for this worker (0 if DP not enabled)
#[serde(default = "default_data_parallel_start_rank")]
pub data_parallel_start_rank: u32,
/// Total number of data parallel ranks for this worker (1 if DP not enabled)
#[serde(default = "default_data_parallel_size")]
pub data_parallel_size: u32,
......@@ -55,6 +59,10 @@ pub struct ModelRuntimeConfig {
pub disaggregated_endpoint: Option<DisaggregatedEndpoint>,
}
const fn default_data_parallel_start_rank() -> u32 {
0
}
const fn default_data_parallel_size() -> u32 {
1
}
......@@ -71,6 +79,7 @@ impl Default for ModelRuntimeConfig {
max_num_batched_tokens: None,
tool_call_parser: None,
reasoning_parser: None,
data_parallel_start_rank: default_data_parallel_start_rank(),
data_parallel_size: default_data_parallel_size(),
enable_local_indexer: true,
runtime_data: HashMap::new(),
......@@ -81,6 +90,10 @@ impl Default for ModelRuntimeConfig {
}
impl dynamo_kv_router::WorkerConfigLike for ModelRuntimeConfig {
fn data_parallel_start_rank(&self) -> u32 {
self.data_parallel_start_rank
}
fn data_parallel_size(&self) -> u32 {
self.data_parallel_size
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment