Unverified Commit 1caf6659 authored by Aniket Kulkarni's avatar Aniket Kulkarni Committed by GitHub
Browse files

feat(kvbm): KVBM MLA support optimization (#7015)


Signed-off-by: default avataraknvda <anikkulkarni@nvidia.com>
Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Co-authored-by: default avatarjthomson04 <jwillthomson19@gmail.com>
parent ac60a0e1
......@@ -235,6 +235,23 @@ To disable disk offload filtering:
export DYN_KVBM_DISABLE_DISK_OFFLOAD_FILTER=true
```
### NCCL Replicated Mode for MLA Models
For MLA (Multi-Layer Attention) models such as DeepSeek, KVBM can use **NCCL replicated mode** so that only rank 0 loads KV blocks from G2/G3 storage and then broadcasts them to all GPUs via NCCL. This avoids redundant loads and can improve performance when multiple GPUs share the same replicated KV cache.
**Enable NCCL MLA mode:**
```bash
export DYN_KVBM_NCCL_MLA_MODE=true
```
**Requirements:**
- MPI must be initialized (e.g., when launching with `mpirun` or equivalent) so that rank and world size are available for NCCL.
- For optimal broadcast-based replication, build KVBM with the NCCL feature: `cargo build -p kvbm --features nccl`. Without it, the connector falls back to worker-level replication (each GPU loads independently).
When disabled (default), each GPU loads KV blocks independently. Set `DYN_KVBM_NCCL_MLA_MODE=true` when running MLA models with KVBM to use the NCCL broadcast optimization.
## Enable and View KVBM Metrics
### Setup Monitoring Stack
......
......@@ -22,6 +22,7 @@ crate-type = ["cdylib", "rlib"]
[features]
default = ["block-manager"]
block-manager = ["dynamo-llm/block-manager", "dep:dlpark", "dep:cudarc"]
nccl = ["block-manager", "dynamo-llm/nccl", "cudarc/nccl"] # Enable NCCL collective operations for replicated mode
[dependencies]
dynamo-llm = { path = "../../llm" }
......
......@@ -88,6 +88,7 @@ Note that the default pip wheel built is not compatible with CUDA 13 at the mome
| `DYN_KVBM_METRICS_PORT` | Metrics port | `6880` |
| `DYN_KVBM_DISABLE_DISK_OFFLOAD_FILTER` | Disable disk offload filtering to remove SSD lifespan protection | `false` |
| `DYN_KVBM_HOST_OFFLOAD_PREFIX_MIN_PRIORITY` | Minimum priority (0-100) for CPU offload with contiguous (prefix) semantics: offloading stops at the first block below threshold, and all subsequent blocks are also skipped. Used for priority-based filtering. | `0` (no filtering) |
| `DYN_KVBM_NCCL_MLA_MODE` | Enable NCCL replicated mode for MLA (Multi-Layer Attention) models (e.g., DeepSeek). When set to `true`, rank 0 loads KV blocks from G2/G3 storage and broadcasts to all GPUs via NCCL instead of each GPU loading independently. Requires MPI and optional `nccl` feature for optimal behavior. | `false` |
#### Disk Storage Configuration
......
......@@ -3,6 +3,171 @@
from typing import Any, List, Optional
class NcclBootstrap:
"""
NCCL bootstrap for creating dedicated KVBM communicators.
This class provides methods to generate, serialize, deserialize,
and initialize NCCL communicators for KVBM's replicated mode.
Usage pattern:
1. Rank 0: Call `NcclBootstrap.generate(world_size)` to create a new unique ID
2. Rank 0: Call `serialize()` and broadcast to other ranks via MPI
3. Other ranks: Call `NcclBootstrap.deserialize(bytes)` to reconstruct
4. All ranks: Call `init_communicator(rank)` collectively to create the comm
"""
@staticmethod
def generate(world_size: int) -> "NcclBootstrap":
"""
Generate a new unique ID for NCCL communicator initialization.
This should only be called on rank 0.
Parameters:
-----------
world_size: int
The total number of ranks that will participate
Returns:
--------
NcclBootstrap
A new NcclBootstrap instance
"""
...
def serialize(self) -> bytes:
"""
Serialize the bootstrap data for distribution to other ranks.
Returns:
--------
bytes
The serialized bootstrap data (136 bytes)
"""
...
@staticmethod
def deserialize(data: bytes) -> "NcclBootstrap":
"""
Deserialize bootstrap data received from rank 0.
Parameters:
-----------
data: bytes
The serialized bootstrap data (136 bytes)
Returns:
--------
NcclBootstrap
A new NcclBootstrap instance
"""
...
def init_communicator(self, rank: int) -> "NcclCommRef":
"""
Initialize the NCCL communicator.
IMPORTANT: This is a collective operation!
All ranks must call this function together with matching parameters.
The function will block until all ranks have called it.
Returns an owning NcclCommRef; pass it to workers so the comm is
kept alive. The communicator is destroyed when the last reference is dropped.
Parameters:
-----------
rank: int
This rank's ID (0 to world_size-1)
Returns:
--------
NcclCommRef
Owning reference; pass to KvbmWorker/PyTrtllmKvConnectorWorker
"""
...
def world_size(self) -> int:
"""
Get the world size for this bootstrap.
Returns:
--------
int
The world size
"""
...
class NcclCommRef:
"""
Owning reference to an NCCL communicator; calls ncclCommDestroy on drop.
Returned by NcclBootstrap.init_communicator. Pass to workers
(KvbmWorker, PyTrtllmKvConnectorWorker) so they keep the comm alive.
"""
def as_raw(self) -> int:
"""
Raw ncclComm_t pointer as an integer (borrowed; do not destroy).
"""
...
class KvbmWorker:
"""
A KVBM worker that handles block transfers.
"""
def __init__(
self,
num_device_blocks: int,
page_size: int,
tensors: List[Any],
device_id: int = 0,
dtype_width_bytes: int = 2,
drt: Optional[Any] = None,
layout_blocking: bool = False,
device_layout_type: Optional[Any] = None,
host_layout_type: Optional[Any] = None,
disk_layout_type: Optional[Any] = None,
rank: Optional[int] = None,
world_size: Optional[int] = None,
nccl_comm_ref: Optional["NcclCommRef"] = None,
) -> None:
"""
Create a KvbmWorker instance.
Parameters:
-----------
num_device_blocks: int
Number of device blocks to manage
page_size: int
Page size for blocks
tensors: List[Any]
List of tensor objects (e.g., torch.Tensor)
device_id: int
CUDA device ID, defaults to 0
dtype_width_bytes: int
Data type width in bytes, defaults to 2 (fp16)
drt: Optional[Any]
Distributed runtime, if applicable
layout_blocking: bool
Whether to block on layout initialization, defaults to False
device_layout_type: Optional[Any]
Layout type for device blocks
host_layout_type: Optional[Any]
Layout type for host blocks
disk_layout_type: Optional[Any]
Layout type for disk blocks
rank: Optional[int]
Rank for replicated mode (None = sharded mode)
world_size: Optional[int]
World size for replicated mode
nccl_comm_ref: Optional[NcclCommRef]
Owning NCCL comm ref for replicated mode (from NcclBootstrap.init_communicator)
"""
...
class Layer:
"""
A KV cache block layer
......@@ -220,3 +385,296 @@ class KvbmRequest:
def __init__(self, request_id: int, tokens: List[int], block_size: int) -> None:
...
class SchedulerOutput:
"""
Scheduler output containing information about scheduled requests.
"""
new_requests: List[Any]
cached_requests: List[Any]
num_scheduled_tokens: dict
def __init__(self) -> None:
...
class PyTrtllmKvConnectorWorker:
"""
TensorRT-LLM KV connector worker for KVBM integration.
This class handles KV cache operations on the worker side for TRT-LLM,
including registration of KV caches, offloading, and loading operations.
"""
def __init__(
self,
py_drt: Optional[Any],
trtllm_rank: str,
nccl_rank: Optional[int] = None,
world_size: Optional[int] = None,
nccl_comm_ref: Optional["NcclCommRef"] = None,
) -> None:
"""
Create a PyTrtllmKvConnectorWorker instance.
Parameters:
-----------
py_drt: Optional[Any]
The distributed runtime object (DistributedRuntime)
trtllm_rank: str
The TRT-LLM rank identifier
nccl_rank: Optional[int]
NCCL rank for replicated mode (None = sharded mode).
Required for MLA support optimization (replicated mode).
world_size: Optional[int]
World size for replicated mode.
Required for MLA support optimization.
nccl_comm_ref: Optional[NcclCommRef]
Owning NCCL comm ref from NcclBootstrap.init_communicator().
Required for MLA support optimization.
"""
...
def register_kv_caches(
self,
num_device_blocks: int,
page_size: int,
device_id: int,
dtype_width_bytes: int,
kv_cache_tensor: Any,
raw_event_handles: List[int],
) -> None:
"""
Register KV cache tensors with the connector worker.
Parameters:
-----------
num_device_blocks: int
Number of device blocks to manage
page_size: int
Page size for blocks
device_id: int
CUDA device ID
dtype_width_bytes: int
Data type width in bytes (e.g., 2 for fp16)
kv_cache_tensor: Any
The KV cache tensor (torch.Tensor)
raw_event_handles: List[int]
List of raw CUDA event handles
"""
...
def bind_connector_meta(self, metadata: bytes) -> None:
"""
Bind connector metadata from the leader.
Parameters:
-----------
metadata: bytes
Serialized connector metadata
"""
...
def execute_offload_operations(self) -> None:
"""
Execute pending offload operations.
"""
...
def save_kv_layer(self, layer_idx: int) -> None:
"""
Save a KV cache layer.
Parameters:
-----------
layer_idx: int
Index of the layer to save
"""
...
def start_load_kv(self) -> None:
"""
Start loading KV cache data.
"""
...
def get_finished(
self,
finished_gen_req_ids: List[int],
started_loading_req_ids: List[int],
) -> tuple:
"""
Get finished offloading and onboarding request IDs.
Parameters:
-----------
finished_gen_req_ids: List[int]
List of request IDs that have finished generation
started_loading_req_ids: List[int]
List of request IDs that have started loading
Returns:
--------
tuple
A tuple of (finished_offloading, finished_onboarding) request ID lists
"""
...
def submit_offload_on_event(self, event: int) -> None:
"""
Submit offload operations to be executed when the given event completes.
Parameters:
-----------
event: int
Raw CUDA event handle
"""
...
class PyTrtllmKvConnectorLeader:
"""
TensorRT-LLM KV connector leader for KVBM integration.
This class handles KV cache coordination on the leader side for TRT-LLM,
including slot management, token matching, and metadata building.
"""
def __init__(
self,
worker_id: int,
drt: Optional[Any],
page_size: int,
leader: Any,
consolidator_trtllm_endpoint: Optional[str] = None,
consolidator_output_endpoint: Optional[str] = None,
) -> None:
"""
Create a PyTrtllmKvConnectorLeader instance.
Parameters:
-----------
worker_id: int
The worker ID for this leader
drt: Optional[Any]
The distributed runtime object (currently unused)
page_size: int
Page size for blocks
leader: Any
The KVBM leader object (PyKvbmLeader)
consolidator_trtllm_endpoint: Optional[str]
TRT-LLM consolidator endpoint
consolidator_output_endpoint: Optional[str]
Output consolidator endpoint
"""
...
def get_num_new_matched_tokens(
self,
request_id: str,
request_num_tokens: int,
num_computed_tokens: int,
) -> tuple:
"""
Get the number of newly matched tokens for a request.
Parameters:
-----------
request_id: str
The request identifier
request_num_tokens: int
Total number of tokens in the request
num_computed_tokens: int
Number of already computed tokens
Returns:
--------
tuple
A tuple of (num_matched_tokens, is_complete)
"""
...
def update_state_after_alloc(
self,
request_id: str,
block_ids: List[int],
context_current_position: int,
) -> None:
"""
Update state after block allocation.
Parameters:
-----------
request_id: str
The request identifier
block_ids: List[int]
List of allocated block IDs
context_current_position: int
Current context position
"""
...
def build_connector_metadata(self, scheduler_output: SchedulerOutput) -> bytes:
"""
Build connector metadata from scheduler output.
Parameters:
-----------
scheduler_output: SchedulerOutput
The scheduler output
Returns:
--------
bytes
Serialized connector metadata
"""
...
def request_finished(self, request_id: str, block_ids: List[int]) -> bool:
"""
Mark a request as finished.
Parameters:
-----------
request_id: str
The request identifier
block_ids: List[int]
List of block IDs used by the request
Returns:
--------
bool
True if the request was successfully marked as finished
"""
...
def has_slot(self, request_id: str) -> bool:
"""
Check if a slot exists for the given request.
Parameters:
-----------
request_id: str
The request identifier
Returns:
--------
bool
True if a slot exists
"""
...
def create_slot(self, request: KvbmRequest, tokens: List[int]) -> None:
"""
Create a slot for a request.
Parameters:
-----------
request: KvbmRequest
The KVBM request object
tokens: List[int]
List of tokens for the request
"""
...
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import os
from typing import Optional, Tuple
import torch
from kvbm.trtllm_integration.rust import KvConnectorWorker as RustKvConnectorWorker
......@@ -15,6 +16,99 @@ if is_dyn_runtime_enabled():
from dynamo.runtime import DistributedRuntime
def _get_mpi_info() -> Tuple[Optional[int], Optional[int]]:
"""Get MPI rank and world_size if MPI is initialized.
Returns:
Tuple of (rank, world_size), or (None, None) if MPI is not available/initialized.
"""
try:
from mpi4py import MPI
if MPI.Is_initialized():
comm = MPI.COMM_WORLD
return comm.Get_rank(), comm.Get_size()
except ImportError:
pass
except Exception as e:
logger.warning(f"Failed to get MPI info: {e}")
return None, None
def _create_kvbm_nccl_comm(rank: int, world_size: int):
"""Create a dedicated NCCL communicator for KVBM using MPI for bootstrap.
This function creates an NCCL communicator that is separate from any other
communicators (e.g., TRT-LLM's). The bootstrap uses MPI to distribute the
unique ID from rank 0 to all other ranks.
Args:
rank: This process's rank (0 to world_size-1)
world_size: Total number of ranks
Returns:
NcclCommRef: Owning reference; pass to the worker so the comm is
kept alive and destroyed when the worker is done.
Raises:
ImportError: If mpi4py or NcclBootstrap is not available
RuntimeError: If NCCL initialization fails
"""
from mpi4py import MPI
try:
from kvbm._core import NcclBootstrap
except ImportError:
raise ImportError(
"NcclBootstrap not available. "
"Make sure kvbm was built with the 'nccl' feature enabled."
)
comm = MPI.COMM_WORLD
# Rank 0 generates unique ID
if rank == 0:
bootstrap = NcclBootstrap.generate(world_size)
bootstrap_data = bootstrap.serialize()
else:
bootstrap_data = None
# Broadcast bootstrap data to all ranks
logger.info(
f"KVBM: Rank {rank} entering bcast (data_len={len(bootstrap_data) if bootstrap_data else 0})"
)
bootstrap_data = comm.bcast(bootstrap_data, root=0)
logger.info(
f"KVBM: Rank {rank} received bootstrap data (len={len(bootstrap_data)})"
)
# Non-rank-0 deserializes the data
if rank != 0:
bootstrap = NcclBootstrap.deserialize(bootstrap_data)
logger.info(f"KVBM: Rank {rank} bootstrap world_size={bootstrap.world_size()}")
# Trust the framework (TRT-LLM / MPI launcher) to have already
# set the correct CUDA device for this rank, either via
# CUDA_VISIBLE_DEVICES or its own initialization.
current_device = torch.cuda.current_device()
logger.info(
f"KVBM: Rank {rank} on CUDA device {current_device} "
f"(device_count={torch.cuda.device_count()})"
)
logger.info(f"KVBM: Rank {rank} waiting at MPI barrier " "before ncclCommInitRank")
comm.Barrier()
logger.info(f"KVBM: Rank {rank} passed barrier, " "calling ncclCommInitRank")
# All ranks collectively initialize (must be called together).
# This is a blocking collective operation; returns owning NcclCommRef.
nccl_comm_ref = bootstrap.init_communicator(rank)
logger.info(f"KVBM: Rank {rank} created dedicated NCCL communicator")
return nccl_comm_ref
class DynamoKVBMConnectorWorker(KvCacheConnectorWorker):
def _callable_object(self) -> callable:
assert (
......@@ -44,7 +138,64 @@ class DynamoKVBMConnectorWorker(KvCacheConnectorWorker):
mappings = self._llm_args.parallel_config.to_mapping()
self.rank = mappings.rank
self._connector = RustKvConnectorWorker(self.drt, str(self.rank))
# NCCL replicated mode for MLA support - controlled by feature flag
# Set DYN_KVBM_NCCL_MLA_MODE=true to enable NCCL broadcast optimization for MLA models
nccl_rank, nccl_world_size, nccl_comm_ref = None, None, None
enable_nccl_mla = os.environ.get("DYN_KVBM_NCCL_MLA_MODE", "false").lower() in (
"true",
"1",
"yes",
)
if enable_nccl_mla:
logger.info("KVBM NCCL MLA mode enabled via DYN_KVBM_NCCL_MLA_MODE")
nccl_rank, nccl_world_size = _get_mpi_info()
else:
logger.info(
"KVBM NCCL MLA mode disabled. Set DYN_KVBM_NCCL_MLA_MODE=true to enable "
"NCCL broadcast optimization for MLA models (e.g., DeepSeek)."
)
if enable_nccl_mla and nccl_rank is not None and nccl_world_size is not None:
try:
nccl_comm_ref = _create_kvbm_nccl_comm(nccl_rank, nccl_world_size)
logger.info(
f"KVBM MLA support: NCCL broadcast optimization enabled. "
f"Rank {nccl_rank}/{nccl_world_size}: only rank 0 loads "
f"from G2/G3 storage, then broadcasts to all GPUs."
)
except ImportError:
logger.warning(
"KVBM MLA support: NCCL not compiled. Using worker-level "
"replication (each GPU loads independently). For optimal "
"broadcast-based replication, rebuild with: "
"cargo build -p kvbm --features nccl"
)
nccl_rank, nccl_world_size, nccl_comm_ref = None, None, None
except Exception as e:
logger.warning(
"KVBM MLA support: _create_kvbm_nccl_comm failed (nccl_rank=%s, "
"nccl_world_size=%s). MLA broadcast disabled; using worker-level "
"replication (each GPU loads independently). Error: %s",
nccl_rank,
nccl_world_size,
e,
)
nccl_rank, nccl_world_size, nccl_comm_ref = None, None, None
elif enable_nccl_mla:
logger.info(
"KVBM: MPI not available, using standard sharded mode. "
"For NCCL replicated mode, ensure MPI is initialized."
)
# else: NCCL MLA mode disabled, no additional logging needed
self._connector = RustKvConnectorWorker(
self.drt,
str(self.rank),
nccl_rank=nccl_rank,
world_size=nccl_world_size,
nccl_comm_ref=nccl_comm_ref,
)
self.event = torch.cuda.Event()
# Default to old way of processing offload
......
......@@ -2,7 +2,15 @@
# SPDX-License-Identifier: Apache-2.0
"""
Loader for the Rust-based TensorRT-LLM integration objects, using objects from _vllm_integration for now
Rust-based TensorRT-LLM integration loader.
Uses objects from _vllm_integration module. Type stubs in kvbm/_core.pyi.
KvConnectorWorker (PyTrtllmKvConnectorWorker) signature:
(py_drt, trtllm_rank, nccl_rank=None, world_size=None, nccl_comm_ref=None)
The nccl_rank, world_size, and nccl_comm_ref parameters enable NCCL replicated mode
for MLA (Multi-head Latent Attention) support with broadcast-based KV cache transfers.
"""
try:
......@@ -16,13 +24,16 @@ try:
BlockStates = getattr(_vllm_integration, "BlockStates")
SlotUpdate = getattr(_vllm_integration, "SlotUpdate")
# TRT-LLM connector classes with NCCL replicated mode support
# KvConnectorWorker: optional nccl_rank, world_size, nccl_comm_ref for MLA support
KvConnectorWorker = getattr(_vllm_integration, "PyTrtllmKvConnectorWorker")
KvConnectorLeader = getattr(_vllm_integration, "PyTrtllmKvConnectorLeader")
SchedulerOutput = getattr(_vllm_integration, "SchedulerOutput")
except ImportError:
print(
"Failed to import Dynamo KVBM. TensorRT-LLM integration will not be available."
"Failed to import Dynamo KVBM. "
"TensorRT-LLM integration will not be available."
)
KvbmRequest = None
KvbmBlockList = None
......
......@@ -30,6 +30,9 @@ pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<controller::BlockPoolStatus>()?;
m.add_class::<controller::ResetBlocksResponse>()?;
m.add_class::<distributed::PyNcclBootstrap>()?;
m.add_class::<distributed::PyNcclCommRef>()?;
vllm::add_to_module(m)?;
Ok(())
......
......@@ -10,3 +10,7 @@ mod worker;
pub use leader::KvbmLeader;
pub use utils::{get_leader_zmq_ack_url, get_leader_zmq_pub_url};
pub use worker::{KvbmWorker, PyLayoutType, VllmTensor};
#[cfg(feature = "nccl")]
pub use worker::{PyNcclBootstrap, PyNcclCommRef};
#[cfg(not(feature = "nccl"))]
pub use worker::{PyNcclBootstrap, PyNcclCommRef};
......@@ -8,11 +8,64 @@ use utils::{get_leader_zmq_ack_url, get_leader_zmq_pub_url};
use llm_rs::block_manager::distributed::{
BlockTransferHandler as RustBlockTransferHandler, KvbmWorker as KvbmWorkerImpl,
KvbmWorkerConfig,
KvbmWorkerConfig, NcclConfig,
};
#[cfg(feature = "nccl")]
use llm_rs::block_manager::distributed::{NcclBootstrap, NcclCommOwned};
use llm_rs::block_manager::layout::LayoutType;
use llm_rs::block_manager::storage::torch::{TorchDevice, TorchTensor};
/// Build NcclConfig from Python parameters.
///
/// Returns an error if NCCL parameters are provided but the NCCL feature is not enabled.
fn build_nccl_config(
rank: Option<i32>,
world_size: Option<i32>,
nccl_comm_ptr: Option<usize>,
) -> anyhow::Result<NcclConfig> {
// Check if the user is trying to use replicated mode
let wants_replicated = rank.is_some() || world_size.is_some() || nccl_comm_ptr.is_some();
#[cfg(feature = "nccl")]
{
match (rank, world_size, nccl_comm_ptr) {
(Some(r), Some(ws), Some(ptr)) if ptr != 0 => {
use cudarc::nccl::sys::ncclComm_t;
Ok(unsafe { NcclConfig::enabled(ptr as ncclComm_t, r, ws) })
}
(Some(r), Some(ws), Some(0)) => anyhow::bail!(
"NCCL replicated mode requires a valid communicator: rank={}, world_size={}, nccl_comm_ptr=0 (invalid). \
Provide a non-null nccl_comm_ptr or omit rank/world_size/nccl_comm_ptr for sharded mode.",
r,
ws
),
(r, ws, ptr) if wants_replicated => anyhow::bail!(
"NCCL replicated mode requires rank, world_size, and nccl_comm_ptr together; \
partial or invalid configuration is not allowed. Got rank={:?}, world_size={:?}, nccl_comm_ptr={:?}. \
Provide all three (with nccl_comm_ptr != 0) or omit all for sharded mode.",
r,
ws,
ptr
),
_ => Ok(NcclConfig::disabled()),
}
}
#[cfg(not(feature = "nccl"))]
{
if wants_replicated {
anyhow::bail!(
"NCCL replicated mode requested (rank={:?}, world_size={:?}, nccl_comm_ptr={:?}) \
but kvbm was not built with the 'nccl' feature enabled. \
Please rebuild with 'nccl' feature or use sharded mode (omit rank/world_size/nccl_comm_ptr).",
rank,
world_size,
nccl_comm_ptr
);
}
Ok(NcclConfig::disabled())
}
}
/// A wrapper around a layout type.
/// This is used to convert between the Python and Rust layout types.
#[pyclass(eq, eq_int)]
......@@ -132,6 +185,9 @@ impl BlockTransferHandler {
pub struct KvbmWorker {
inner: Arc<Mutex<KvbmWorkerImpl>>,
_drt: Option<Arc<rs::DistributedRuntime>>,
/// Keeps the NCCL communicator alive for the worker lifetime; dropped with the worker.
#[cfg(feature = "nccl")]
_nccl_comm: Option<Arc<NcclCommOwned>>,
}
impl KvbmWorker {
......@@ -143,8 +199,8 @@ impl KvbmWorker {
#[pymethods]
impl KvbmWorker {
#[new]
#[pyo3(signature = (num_device_blocks, page_size, tensors, device_id=0, dtype_width_bytes=2, drt=None, layout_blocking=false, device_layout_type=None, host_layout_type=None, disk_layout_type=None, rank=None, world_size=None, nccl_comm_ref=None))]
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (num_device_blocks, page_size, tensors, device_id=0, dtype_width_bytes=2, drt=None, layout_blocking=false, device_layout_type=None, host_layout_type=None, disk_layout_type=None))]
fn new(
num_device_blocks: usize,
page_size: usize,
......@@ -156,6 +212,9 @@ impl KvbmWorker {
device_layout_type: Option<PyLayoutType>,
host_layout_type: Option<PyLayoutType>,
disk_layout_type: Option<PyLayoutType>,
rank: Option<i32>,
world_size: Option<i32>,
nccl_comm_ref: Option<PyObject>,
) -> PyResult<Self> {
let drt: Option<Arc<rs::DistributedRuntime>> = Python::with_gil(|py| {
if let Some(obj) = drt {
......@@ -174,6 +233,27 @@ impl KvbmWorker {
vllm_tensors.push(Arc::new(vllm_tensor));
}
// Own the NCCL communicator for the worker lifetime; NcclConfig gets a borrowed handle.
#[cfg(feature = "nccl")]
let _nccl_comm = nccl_comm_ref.as_ref().and_then(|obj| {
Python::with_gil(|py| {
obj.downcast_bound::<PyNcclCommRef>(py)
.ok()
.map(|r| r.borrow().get_arc())
})
});
#[cfg(feature = "nccl")]
let nccl_comm_ptr = _nccl_comm.as_ref().map(|a| a.as_raw() as usize);
#[cfg(not(feature = "nccl"))]
let nccl_comm_ptr: Option<usize> = None;
#[cfg(not(feature = "nccl"))]
let _ = nccl_comm_ref;
// Build NcclConfig from owned comm (borrowed handle only)
let nccl_config = build_nccl_config(rank, world_size, nccl_comm_ptr).map_err(to_pyerr)?;
// When NCCL is disabled, pass None for rank/world_size so the worker is consistently in sharded mode.
let worker_rank = if nccl_config.is_enabled() { rank } else { None };
let config = KvbmWorkerConfig::builder()
.cancel_token(get_current_cancel_token())
.num_device_blocks(num_device_blocks)
......@@ -198,6 +278,8 @@ impl KvbmWorker {
)
.leader_pub_url(get_leader_zmq_pub_url())
.leader_ack_url(get_leader_zmq_ack_url())
.rank(worker_rank)
.nccl_config(nccl_config)
.build()
.map_err(to_pyerr)?;
......@@ -211,6 +293,182 @@ impl KvbmWorker {
Ok(Self {
inner: Arc::new(Mutex::new(worker)),
_drt: drt,
#[cfg(feature = "nccl")]
_nccl_comm,
})
}
}
/// Owning wrapper for an NCCL communicator; calls ncclCommDestroy on drop.
///
/// Returned by NcclBootstrap.init_communicator. Pass this object to workers
/// (e.g. KvbmWorker, PyTrtllmKvConnectorWorker) so they keep the comm alive.
/// The raw handle is exposed via as_raw() for NcclConfig; ownership stays here.
#[cfg(feature = "nccl")]
#[pyclass(name = "NcclCommRef")]
pub struct PyNcclCommRef {
inner: Arc<NcclCommOwned>,
}
#[cfg(feature = "nccl")]
#[pymethods]
impl PyNcclCommRef {
/// Raw ncclComm_t pointer as an integer (borrowed; do not destroy).
/// Used by NcclConfig; the communicator is owned by this ref.
fn as_raw(&self) -> usize {
self.inner.as_raw() as usize
}
}
#[cfg(feature = "nccl")]
impl PyNcclCommRef {
/// Clone the inner Arc so a worker can hold ownership until drop.
pub fn get_arc(&self) -> Arc<NcclCommOwned> {
Arc::clone(&self.inner)
}
}
/// Python wrapper for NCCL bootstrap functionality.
///
/// This class provides methods to generate, serialize, deserialize,
/// and initialize NCCL communicators for KVBM's replicated mode.
///
/// Usage pattern:
/// 1. Rank 0: Call `NcclBootstrap.generate(world_size)` to create a new unique ID
/// 2. Rank 0: Call `serialize()` and broadcast to other ranks via MPI
/// 3. Other ranks: Call `NcclBootstrap.deserialize(bytes)` to reconstruct
/// 4. All ranks: Call `init_communicator(rank)` collectively to create the comm
#[cfg(feature = "nccl")]
#[pyclass(name = "NcclBootstrap")]
pub struct PyNcclBootstrap {
inner: NcclBootstrap,
}
#[cfg(feature = "nccl")]
#[pymethods]
impl PyNcclBootstrap {
/// Generate a new unique ID for NCCL communicator initialization.
/// This should only be called on rank 0.
///
/// Args:
/// world_size: The total number of ranks that will participate
///
/// Returns:
/// A new PyNcclBootstrap instance
#[staticmethod]
fn generate(world_size: i32) -> PyResult<Self> {
let inner = NcclBootstrap::generate(world_size).map_err(to_pyerr)?;
Ok(Self { inner })
}
/// Serialize the bootstrap data for distribution to other ranks.
///
/// Returns:
/// bytes: The serialized bootstrap data (136 bytes)
fn serialize<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, pyo3::types::PyBytes>> {
let bytes = self.inner.serialize();
Ok(pyo3::types::PyBytes::new(py, &bytes))
}
/// Deserialize bootstrap data received from rank 0.
///
/// Args:
/// data: The serialized bootstrap data (136 bytes)
///
/// Returns:
/// A new PyNcclBootstrap instance
#[staticmethod]
fn deserialize(data: &[u8]) -> PyResult<Self> {
let inner = NcclBootstrap::deserialize(data).map_err(to_pyerr)?;
Ok(Self { inner })
}
/// Initialize the NCCL communicator.
///
/// IMPORTANT: This is a collective operation!
/// All ranks must call this function together with matching parameters.
/// The function will block until all ranks have called it.
///
/// Returns an owning NcclCommRef; keep it alive for the lifetime of any
/// worker using this communicator. The communicator is destroyed when
/// the last reference is dropped.
///
/// Args:
/// rank: This rank's ID (0 to world_size-1)
///
/// Returns:
/// NcclCommRef: Owning wrapper; use .as_raw() for the raw handle if needed
fn init_communicator(&self, rank: i32) -> PyResult<PyNcclCommRef> {
let comm = self.inner.init_communicator(rank).map_err(to_pyerr)?;
let owned = unsafe { NcclCommOwned::from_raw(comm) };
Ok(PyNcclCommRef {
inner: Arc::new(owned),
})
}
/// Get the world size for this bootstrap.
fn world_size(&self) -> i32 {
self.inner.world_size()
}
}
// ---------------------------------------------------------------------------
// Stub implementations when nccl feature is disabled (match .pyi; raise on use)
// ---------------------------------------------------------------------------
#[cfg(not(feature = "nccl"))]
const NCCL_UNAVAILABLE_MSG: &str = "kvbm was not built with the 'nccl' feature. Rebuild with --features nccl to use NcclBootstrap/NcclCommRef.";
#[cfg(not(feature = "nccl"))]
#[pyclass(name = "NcclCommRef")]
pub struct PyNcclCommRef;
#[cfg(not(feature = "nccl"))]
#[pymethods]
impl PyNcclCommRef {
fn as_raw(&self) -> PyResult<usize> {
Err(pyo3::exceptions::PyRuntimeError::new_err(
NCCL_UNAVAILABLE_MSG,
))
}
}
#[cfg(not(feature = "nccl"))]
#[pyclass(name = "NcclBootstrap")]
pub struct PyNcclBootstrap;
#[cfg(not(feature = "nccl"))]
#[pymethods]
impl PyNcclBootstrap {
#[staticmethod]
fn generate(_world_size: i32) -> PyResult<Self> {
Err(pyo3::exceptions::PyRuntimeError::new_err(
NCCL_UNAVAILABLE_MSG,
))
}
fn serialize<'py>(&self, _py: Python<'py>) -> PyResult<Bound<'py, pyo3::types::PyBytes>> {
Err(pyo3::exceptions::PyRuntimeError::new_err(
NCCL_UNAVAILABLE_MSG,
))
}
#[staticmethod]
fn deserialize(_data: &[u8]) -> PyResult<Self> {
Err(pyo3::exceptions::PyRuntimeError::new_err(
NCCL_UNAVAILABLE_MSG,
))
}
fn init_communicator(&self, _rank: i32) -> PyResult<PyNcclCommRef> {
Err(pyo3::exceptions::PyRuntimeError::new_err(
NCCL_UNAVAILABLE_MSG,
))
}
fn world_size(&self) -> PyResult<i32> {
Err(pyo3::exceptions::PyRuntimeError::new_err(
NCCL_UNAVAILABLE_MSG,
))
}
}
......@@ -10,6 +10,8 @@ use std::collections::HashSet;
use std::sync::{Arc, OnceLock};
use super::*;
#[cfg(feature = "nccl")]
use crate::block_manager::distributed::PyNcclCommRef;
use crate::block_manager::distributed::{get_leader_zmq_ack_url, get_leader_zmq_pub_url};
use crate::block_manager::vllm::connector::worker::event_sync_blocking;
use crate::{block_manager::distributed::VllmTensor, to_pyerr};
......@@ -19,7 +21,9 @@ use crate::{
extract_distributed_runtime_from_obj, get_current_cancel_token, get_current_tokio_handle,
};
use anyhow;
use dynamo_llm::block_manager::distributed::{KvbmWorker, KvbmWorkerConfig};
#[cfg(feature = "nccl")]
use dynamo_llm::block_manager::distributed::NcclCommOwned;
use dynamo_llm::block_manager::distributed::{KvbmWorker, KvbmWorkerConfig, NcclConfig};
use dynamo_llm::block_manager::layout::LayoutType;
use dynamo_llm::block_manager::storage::torch::TorchTensor;
use dynamo_runtime::utils::task::CriticalTaskExecutionHandle;
......@@ -76,10 +80,26 @@ pub struct KvConnectorWorker {
/// cuda events created by the python side
layer_events: Vec<u64>,
/// NCCL rank for replicated mode (None = sharded mode)
nccl_rank: Option<i32>,
/// World size for NCCL replicated mode
world_size: Option<i32>,
/// Owned NCCL communicator; kept alive for worker lifetime, NcclConfig uses borrowed handle.
#[cfg(feature = "nccl")]
nccl_comm: Option<Arc<NcclCommOwned>>,
}
impl KvConnectorWorker {
fn new(drt: Option<Arc<DistributedRuntime>>, trtllm_rank: String) -> anyhow::Result<Self> {
fn new(
drt: Option<Arc<DistributedRuntime>>,
trtllm_rank: String,
nccl_rank: Option<i32>,
world_size: Option<i32>,
nccl_comm_ref: Option<pyo3::PyObject>,
) -> anyhow::Result<Self> {
let runtime = get_current_tokio_handle();
let (scheduler, worker_client, transfer_client) =
......@@ -96,9 +116,27 @@ impl KvConnectorWorker {
)?
.detach();
#[cfg(feature = "nccl")]
let nccl_comm = nccl_comm_ref.as_ref().and_then(|obj| {
pyo3::Python::with_gil(|py| {
obj.downcast_bound::<PyNcclCommRef>(py)
.ok()
.map(|r| r.borrow().get_arc())
})
});
#[cfg(not(feature = "nccl"))]
let _ = nccl_comm_ref;
tracing::info!(
"KvConnectorWorker initialized with worker_rank: {}",
trtllm_rank
"KvConnectorWorker initialized with worker_rank: {}, nccl_rank: {:?}, world_size: {:?}, nccl_comm_ref: {}",
trtllm_rank,
nccl_rank,
world_size,
if nccl_comm_ref.is_some() {
"Some"
} else {
"None"
}
);
Ok(Self {
......@@ -114,10 +152,72 @@ impl KvConnectorWorker {
iteration: 0,
layers_complete: 0,
layer_events: Vec::new(),
nccl_rank,
world_size,
#[cfg(feature = "nccl")]
nccl_comm,
})
}
}
/// Build NcclConfig from the provided parameters.
///
/// Returns an error if NCCL parameters are partially provided or if the NCCL
/// feature is not enabled but replicated mode was requested. This matches the
/// validation in the distributed worker binding.
fn build_nccl_config(
nccl_rank: Option<i32>,
world_size: Option<i32>,
nccl_comm_ptr: Option<usize>,
) -> anyhow::Result<NcclConfig> {
let wants_replicated = nccl_rank.is_some() || world_size.is_some() || nccl_comm_ptr.is_some();
#[cfg(feature = "nccl")]
{
match (nccl_rank, world_size, nccl_comm_ptr) {
(Some(r), Some(ws), Some(ptr)) if ptr != 0 => {
tracing::info!(
"Creating NCCL config for replicated mode: rank={}, world_size={}, comm_ptr={:#x}",
r,
ws,
ptr
);
use cudarc::nccl::sys::ncclComm_t;
Ok(unsafe { NcclConfig::enabled(ptr as ncclComm_t, r, ws) })
}
(Some(r), Some(ws), Some(0)) => anyhow::bail!(
"NCCL replicated mode requires a valid communicator: rank={}, world_size={}, \
nccl_comm_ptr=0 (invalid). Provide a non-null nccl_comm_ptr or omit all for sharded mode.",
r,
ws
),
(r, ws, ptr) if wants_replicated => anyhow::bail!(
"NCCL replicated mode requires rank, world_size, and nccl_comm_ptr together; \
partial configuration is not allowed. Got rank={:?}, world_size={:?}, \
nccl_comm_ptr={:?}. Provide all three or omit all for sharded mode.",
r,
ws,
ptr
),
_ => Ok(NcclConfig::disabled()),
}
}
#[cfg(not(feature = "nccl"))]
{
if wants_replicated {
anyhow::bail!(
"NCCL replicated mode requested (rank={:?}, world_size={:?}, nccl_comm_ptr={:?}) \
but kvbm was not built with the 'nccl' feature. Rebuild with --features nccl \
or omit rank/world_size/nccl_comm_ptr for sharded mode.",
nccl_rank,
world_size,
nccl_comm_ptr
);
}
Ok(NcclConfig::disabled())
}
}
impl Worker for KvConnectorWorker {
fn register_kv_caches(
&mut self,
......@@ -135,6 +235,19 @@ impl Worker for KvConnectorWorker {
let kv_cache_tensors = vec![kv_cache_tensor as Arc<dyn TorchTensor>];
// Build NCCL config from owned comm (borrowed handle only)
#[cfg(feature = "nccl")]
let nccl_comm_ptr = self.nccl_comm.as_ref().map(|a| a.as_raw() as usize);
#[cfg(not(feature = "nccl"))]
let nccl_comm_ptr: Option<usize> = None;
let nccl_config = build_nccl_config(self.nccl_rank, self.world_size, nccl_comm_ptr)?;
// When NCCL is disabled, pass None for rank/world_size so the worker is consistently in sharded mode.
let nccl_rank = if nccl_config.is_enabled() {
self.nccl_rank
} else {
None
};
let config = KvbmWorkerConfig::builder()
.cancel_token(get_current_cancel_token())
.num_device_blocks(num_device_blocks)
......@@ -148,6 +261,8 @@ impl Worker for KvConnectorWorker {
.leader_pub_url(get_leader_zmq_pub_url())
.leader_ack_url(get_leader_zmq_ack_url())
.scheduler_client(Some(self.transfer_client.clone()))
.rank(nccl_rank)
.nccl_config(nccl_config)
.build()?;
self.layer_events = raw_event_handles;
......@@ -444,8 +559,14 @@ pub struct PyTrtllmKvConnectorWorker {
#[pymethods]
impl PyTrtllmKvConnectorWorker {
#[new]
#[pyo3(signature = (py_drt, trtllm_rank))]
pub fn new(py_drt: Option<PyObject>, trtllm_rank: String) -> PyResult<Self> {
#[pyo3(signature = (py_drt, trtllm_rank, nccl_rank=None, world_size=None, nccl_comm_ref=None))]
pub fn new(
py_drt: Option<PyObject>,
trtllm_rank: String,
nccl_rank: Option<i32>,
world_size: Option<i32>,
nccl_comm_ref: Option<PyObject>,
) -> PyResult<Self> {
let drt: Option<Arc<DistributedRuntime>> = Python::with_gil(|py| {
if let Some(obj) = py_drt {
extract_distributed_runtime_from_obj(py, obj)
......@@ -454,8 +575,10 @@ impl PyTrtllmKvConnectorWorker {
}
})?;
let connector_worker: Box<dyn Worker> =
Box::new(KvConnectorWorker::new(drt, trtllm_rank).map_err(to_pyerr)?);
let connector_worker: Box<dyn Worker> = Box::new(
KvConnectorWorker::new(drt, trtllm_rank, nccl_rank, world_size, nccl_comm_ref)
.map_err(to_pyerr)?,
);
Ok(Self { connector_worker })
}
......
......@@ -21,9 +21,10 @@ testing-full = ["testing-cuda", "testing-nixl"]
testing-cuda = ["dep:cudarc", "dynamo-memory/testing-cuda"]
testing-nixl = ["dep:nixl-sys", "dynamo-memory/testing-nixl"]
testing-etcd = []
block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:nix", "dep:aligned-vec"]
# Forward the NVTX feature to dynamo-runtime (build with --features nvtx or dynamo-llm/nvtx)
nvtx = ["dynamo-runtime/nvtx"]
block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:nix", "dep:aligned-vec"]
nccl = ["dep:cudarc", "cudarc/nccl"] # Enable NCCL collective operations
block-manager-bench = ["block-manager", "testing-full", "dep:clap", "dep:indicatif"]
cuda = ["dep:cudarc"]
integration = ["dynamo-runtime/integration"]
......
......@@ -4,6 +4,8 @@
pub mod context;
mod cuda;
mod memcpy;
#[cfg(feature = "nccl")]
mod nccl;
mod nixl;
mod strategy;
......@@ -23,6 +25,9 @@ pub use crate::block_manager::storage::{CudaAccessible, Local, Remote};
pub use async_trait::async_trait;
pub use context::{PoolConfig, TransferContext};
#[cfg(feature = "nccl")]
pub use nccl::{NcclGroup, bcast_block, bcast_layer};
/// A block that can be the target of a write
pub trait Writable {}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NCCL collective broadcast operations for block data.
//!
//! This module provides functions for broadcasting block data across multiple
//! GPUs using NCCL collective operations.
use super::*;
use std::cell::Cell;
use std::ffi::c_void;
use std::ops::Range;
use anyhow::{Context, Result};
use cudarc::driver::sys::CUstream;
use cudarc::nccl::sys::{
ncclBcast, ncclComm_t, ncclDataType_t, ncclGroupEnd, ncclGroupStart, ncclResult_t,
};
/// Check NCCL result and convert to anyhow::Result
fn check_nccl_result(result: ncclResult_t) -> Result<()> {
if result == ncclResult_t::ncclSuccess {
Ok(())
} else {
anyhow::bail!("NCCL error: {:?}", result)
}
}
/// RAII guard for NCCL group operations.
///
/// Calls `ncclGroupStart` in [`NcclGroup::new`] and `ncclGroupEnd` in [`NcclGroup::end`]
/// (or in [`Drop`] if [`NcclGroup::end`] was not called).
/// Use this to batch multiple NCCL operations efficiently.
///
/// **Call [`NcclGroup::end`] before dropping** so submission errors can be observed.
/// If you drop without calling `end()`, [`Drop`] will call `ncclGroupEnd()` and panic on error.
///
/// # Example
/// ```ignore
/// let mut group = unsafe { NcclGroup::new()? };
/// unsafe { bcast_block(&block1, root, comm, stream)?; }
/// unsafe { bcast_block(&block2, root, comm, stream)?; }
/// group.end()?; // Submit the group; call before drop to observe errors
/// drop(group);
/// ```
///
/// # Safety
/// Creating an `NcclGroup` is unsafe because:
/// - All ranks must create and drop the group collectively
/// - NCCL operations between creation and drop must be valid
pub struct NcclGroup {
/// Tracks whether `ncclGroupEnd` has been successfully called (via `end()` or will be in `Drop`).
ended: Cell<bool>,
}
impl NcclGroup {
/// Start a new NCCL group.
///
/// Calls `ncclGroupStart`. All ranks must call this collectively.
///
/// # Safety
/// - All ranks must call this collectively
/// - Call [`NcclGroup::end`] before drop to observe submission errors; the group must be ended before any synchronization
pub unsafe fn new() -> Result<Self> {
let result = unsafe { ncclGroupStart() };
check_nccl_result(result).context("ncclGroupStart failed")?;
Ok(Self {
ended: Cell::new(false),
})
}
/// End the NCCL group and submit all queued operations.
///
/// Calls `ncclGroupEnd()`. Call this before dropping the guard so submission
/// errors can be observed. If this returns `Ok(())`, [`Drop`] will not call
/// `ncclGroupEnd` again. If you drop without calling `end()`, [`Drop`] will
/// call `ncclGroupEnd()` and panic on error.
///
/// Returns an error if the group was already ended or if `ncclGroupEnd` fails.
pub fn end(&self) -> Result<()> {
if self.ended.get() {
anyhow::bail!("NcclGroup::end called twice");
}
let result = unsafe { ncclGroupEnd() };
check_nccl_result(result).context("ncclGroupEnd failed")?;
self.ended.set(true);
Ok(())
}
}
impl Drop for NcclGroup {
fn drop(&mut self) {
if self.ended.get() {
return; // end() already called ncclGroupEnd successfully
}
// Safety: We started the group in NcclGroup::new (ncclGroupStart); we must end it.
// Panic on error so we do not silently swallow ncclGroupEnd failures.
let result = unsafe { ncclGroupEnd() };
if result != ncclResult_t::ncclSuccess {
panic!(
"ncclGroupEnd failed in NcclGroup drop: {:?}. Call NcclGroup::end() before drop to handle errors.",
result
);
}
}
}
/// Broadcast a block to all ranks.
///
/// If the block is fully contiguous, uses a single NCCL broadcast call.
/// Otherwise, falls back to layer-by-layer broadcast via [`bcast_layer`].
///
/// This function should be called from within an [`NcclGroup`] scope for
/// efficient batching of multiple broadcasts.
///
/// # Safety
/// - `comm` must be a valid NCCL communicator
/// - `stream` must be a valid CUDA stream
/// - All ranks must call this collectively with matching parameters
/// - The block's memory must be valid GPU memory accessible by the NCCL communicator
/// - Should be called within an [`NcclGroup`] scope
///
/// # Arguments
/// * `block` - The block to broadcast (source on root, destination on other ranks)
/// * `root` - The rank that owns the source data
/// * `comm` - The NCCL communicator
/// * `stream` - The CUDA stream to use for the operation
pub unsafe fn bcast_block<B>(block: &B, root: i32, comm: ncclComm_t, stream: CUstream) -> Result<()>
where
B: BlockDataProvider,
{
let data = block.block_data();
if data.is_fully_contiguous() {
let view = data.block_view().context("Failed to get block view")?;
let ptr = unsafe { view.as_ptr() } as usize;
let size = view.size();
let result = unsafe {
ncclBcast(
ptr as *mut c_void,
size,
ncclDataType_t::ncclChar,
root,
comm,
stream.cast(),
)
};
check_nccl_result(result).context("ncclBcast failed")
} else {
// Fall back to layer-by-layer broadcast for non-contiguous blocks
unsafe { bcast_layer(block, None, root, comm, stream) }
}
}
/// Broadcast block layers to all ranks.
///
/// Iterates over layer views and broadcasts each one. Use this when only a
/// subset of layers should be broadcast, or when the block layout is not
/// fully contiguous.
///
/// This function should be called from within an [`NcclGroup`] scope for
/// efficient batching of multiple broadcasts.
///
/// # Safety
/// - `comm` must be a valid NCCL communicator
/// - `stream` must be a valid CUDA stream
/// - All ranks must call this collectively with matching parameters
/// - The block's memory must be valid GPU memory accessible by the NCCL communicator
/// - Should be called within an [`NcclGroup`] scope
///
/// # Arguments
/// * `block` - The block containing layers to broadcast
/// * `layer_range` - Optional range of layers to broadcast. If None, broadcasts all layers.
/// * `root` - The rank that owns the source data
/// * `comm` - The NCCL communicator
/// * `stream` - The CUDA stream to use for the operation
pub unsafe fn bcast_layer<B>(
block: &B,
layer_range: Option<Range<usize>>,
root: i32,
comm: ncclComm_t,
stream: CUstream,
) -> Result<()>
where
B: BlockDataProvider,
{
let data = block.block_data();
let layer_range = layer_range.unwrap_or(0..data.num_layers());
for layer_idx in layer_range {
for outer_idx in 0..data.num_outer_dims() {
let view = data
.layer_view(layer_idx, outer_idx)
.context("Failed to get layer view")?;
let ptr = unsafe { view.as_ptr() } as usize;
let size = view.size();
let result = unsafe {
ncclBcast(
ptr as *mut c_void,
size,
ncclDataType_t::ncclChar,
root,
comm,
stream.cast(),
)
};
check_nccl_result(result).context("ncclBcast failed in layer loop")?;
}
}
Ok(())
}
......@@ -6,10 +6,14 @@ mod utils;
mod zmq;
mod leader;
#[cfg(feature = "nccl")]
mod nccl_bootstrap;
mod worker;
pub use leader::{KvbmLeader, KvbmLeaderConfig, KvbmLeaderNumBlocksConfig};
pub use transfer::BlockTransferHandler;
#[cfg(feature = "nccl")]
pub use nccl_bootstrap::{NcclBootstrap, NcclCommOwned};
pub use transfer::{BlockTransferHandler, NcclConfig};
pub use utils::{
BlockTransferPool, BlockTransferRequest, ConnectorRequestLeader, ConnectorTransferType,
};
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! NCCL bootstrap for creating dedicated KVBM communicators.
//!
//! This module provides infrastructure for bootstrapping NCCL communicators
//! that are dedicated to KVBM operations, separate from other runtime comms.
//!
//! The bootstrap pattern:
//! 1. Rank 0 generates a unique NCCL ID via `ncclGetUniqueId`
//! 2. The unique ID is broadcast to all ranks (via MPI or other mechanism)
//! 3. All ranks collectively call `ncclCommInitRank` to create the communicator
use anyhow::{Context, Result};
use cudarc::nccl::sys::{
ncclComm_t, ncclCommDestroy, ncclCommInitRankConfig, ncclConfig_t, ncclGetUniqueId,
ncclResult_t, ncclUniqueId,
};
/// Check NCCL result and convert to anyhow::Result
fn check_nccl_result(result: ncclResult_t, operation: &str) -> Result<()> {
if result == ncclResult_t::ncclSuccess {
Ok(())
} else {
// Provide detailed error information for debugging
let error_name = match result {
ncclResult_t::ncclUnhandledCudaError => "ncclUnhandledCudaError",
ncclResult_t::ncclSystemError => "ncclSystemError",
ncclResult_t::ncclInternalError => "ncclInternalError",
ncclResult_t::ncclInvalidArgument => "ncclInvalidArgument",
ncclResult_t::ncclInvalidUsage => "ncclInvalidUsage",
ncclResult_t::ncclRemoteError => "ncclRemoteError",
ncclResult_t::ncclInProgress => "ncclInProgress",
_ => "Unknown",
};
anyhow::bail!(
"{} failed with error: {} ({:?}). Check NCCL_DEBUG=INFO for more details.",
operation,
error_name,
result
)
}
}
/// NCCL bootstrap for creating dedicated KVBM communicator.
///
/// This struct holds the unique ID needed to initialize an NCCL communicator
/// across multiple ranks. The typical usage pattern is:
///
/// 1. Rank 0: Call `NcclBootstrap::generate(world_size)` to create a new unique ID
/// 2. Rank 0: Serialize with `serialize()` and broadcast to other ranks
/// 3. Other ranks: Call `NcclBootstrap::deserialize(bytes)` to reconstruct
/// 4. All ranks: Call `init_communicator(rank)` collectively to create the comm
///
/// # Example
/// ```ignore
/// // On rank 0:
/// let bootstrap = NcclBootstrap::generate(world_size)?;
/// let data = bootstrap.serialize();
/// // ... broadcast data via MPI ...
///
/// // On all ranks:
/// let bootstrap = if rank == 0 {
/// bootstrap
/// } else {
/// NcclBootstrap::deserialize(&received_data)?
/// };
///
/// // All ranks call this together:
/// let comm = bootstrap.init_communicator(rank)?;
/// ```
pub struct NcclBootstrap {
unique_id: ncclUniqueId,
world_size: i32,
}
impl NcclBootstrap {
/// Generate a new unique ID for NCCL communicator initialization.
/// This should only be called on rank 0.
///
/// # Arguments
/// * `world_size` - The total number of ranks that will participate
pub fn generate(world_size: i32) -> Result<Self> {
let mut unique_id = ncclUniqueId { internal: [0; 128] };
let result = unsafe { ncclGetUniqueId(&mut unique_id) };
check_nccl_result(result, "ncclGetUniqueId")?;
Ok(Self {
unique_id,
world_size,
})
}
/// Serialize the bootstrap data for distribution to other ranks.
/// Format: 4 bytes world_size (little endian) + 4 bytes padding + 128 bytes unique_id
pub fn serialize(&self) -> Vec<u8> {
let mut bytes = Vec::with_capacity(136);
bytes.extend_from_slice(&self.world_size.to_le_bytes());
bytes.extend_from_slice(&[0u8; 4]); // padding for alignment
let internal_bytes: &[u8; 128] =
unsafe { &*self.unique_id.internal.as_ptr().cast::<[u8; 128]>() };
bytes.extend_from_slice(internal_bytes);
bytes
}
/// Deserialize bootstrap data received from rank 0.
///
/// # Arguments
/// * `bytes` - The serialized bootstrap data (136 bytes)
pub fn deserialize(bytes: &[u8]) -> Result<Self> {
anyhow::ensure!(
bytes.len() == 136,
"Invalid bootstrap data length: expected 136, got {}",
bytes.len()
);
let world_size = i32::from_le_bytes(
bytes[0..4]
.try_into()
.context("Failed to parse world_size")?,
);
let mut unique_id = ncclUniqueId { internal: [0; 128] };
// c_char is i8 on x86_64 but u8 on aarch64; use ptr copy to be portable
unsafe {
std::ptr::copy_nonoverlapping(
bytes[8..136].as_ptr(),
unique_id.internal.as_mut_ptr().cast::<u8>(),
128,
);
}
Ok(Self {
unique_id,
world_size,
})
}
/// Initialize the NCCL communicator.
///
/// # IMPORTANT: This is a collective operation!
/// All ranks must call this function together with matching parameters.
/// The function will block until all ranks have called it.
///
/// # Arguments
/// * `rank` - This rank's ID (0 to world_size-1)
///
/// # Returns
/// The raw `ncclComm_t` handle. The caller is responsible for eventually
/// calling `ncclCommDestroy` on this handle.
///
/// # Safety
/// The returned communicator must be properly destroyed when no longer needed.
pub fn init_communicator(&self, rank: i32) -> Result<ncclComm_t> {
anyhow::ensure!(
rank >= 0 && rank < self.world_size,
"Invalid rank {}: must be in range [0, {})",
rank,
self.world_size
);
// CudaRC doesn't seem to have any nice bindings to the NCCL config.
// We have to manually create it the same way the NCCL C++ macros do.
let mut config: ncclConfig_t;
let max_ctas = std::env::var("DYN_KVBM_NCCL_MAX_CTAS")
.ok()
.and_then(|val| val.parse::<i32>().ok())
.unwrap_or(8);
config = ncclConfig_t {
size: std::mem::size_of::<ncclConfig_t>(),
magic: 0xcafebeef, // Required Magic Number
version: 22800, // NOTE: This needs to be updated whenever we update the NCCL version.
blocking: 1,
cgaClusterSize: i32::MIN,
minCTAs: 1,
maxCTAs: max_ctas,
netName: std::ptr::null_mut(),
splitShare: i32::MIN,
trafficClass: i32::MIN,
commName: std::ptr::null_mut(),
collnetEnable: 0,
CTAPolicy: i32::MIN,
shrinkShare: i32::MIN,
nvlsCTAs: i32::MIN,
nChannelsPerNetPeer: i32::MIN,
nvlinkCentricSched: i32::MIN,
};
let mut comm: ncclComm_t = std::ptr::null_mut();
tracing::debug!(
"Calling ncclCommInitRank: rank={}, world_size={}",
rank,
self.world_size
);
let result = unsafe {
ncclCommInitRankConfig(
&mut comm,
self.world_size,
self.unique_id,
rank,
&mut config,
)
};
check_nccl_result(
result,
&format!(
"ncclCommInitRank(rank={}, world_size={})",
rank, self.world_size
),
)?;
tracing::info!(
"NCCL communicator initialized successfully: rank={}, world_size={}",
rank,
self.world_size
);
Ok(comm)
}
/// Get the world size for this bootstrap.
pub fn world_size(&self) -> i32 {
self.world_size
}
}
/// RAII wrapper for ncclComm_t that destroys the communicator on drop.
pub struct NcclCommOwned {
comm: ncclComm_t,
}
// Safety: NCCL communicators are internally thread-safe.
// NCCL serializes operations on the same communicator.
unsafe impl Send for NcclCommOwned {}
unsafe impl Sync for NcclCommOwned {}
impl NcclCommOwned {
/// Create a new owned communicator from a raw handle.
///
/// # Safety
/// The caller must ensure that `comm` is a valid NCCL communicator
/// that has not been destroyed and is not shared elsewhere.
pub unsafe fn from_raw(comm: ncclComm_t) -> Self {
Self { comm }
}
/// Get the raw communicator handle.
pub fn as_raw(&self) -> ncclComm_t {
self.comm
}
}
impl Drop for NcclCommOwned {
fn drop(&mut self) {
if !self.comm.is_null() {
let result = unsafe { ncclCommDestroy(self.comm) };
if result != ncclResult_t::ncclSuccess {
tracing::error!("Failed to destroy NCCL communicator: {:?}", result);
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize() {
let internal_bytes: [u8; 128] = [42u8; 128];
let mut unique_id = ncclUniqueId { internal: [0; 128] };
unsafe {
std::ptr::copy_nonoverlapping(
internal_bytes.as_ptr(),
unique_id.internal.as_mut_ptr().cast::<u8>(),
128,
);
}
let bootstrap = NcclBootstrap {
unique_id,
world_size: 4,
};
let bytes = bootstrap.serialize();
assert_eq!(bytes.len(), 136);
let restored = NcclBootstrap::deserialize(&bytes).unwrap();
assert_eq!(restored.world_size, 4);
let restored_bytes: &[u8; 128] =
unsafe { &*restored.unique_id.internal.as_ptr().cast::<[u8; 128]>() };
assert_eq!(*restored_bytes, [42u8; 128]);
}
#[test]
fn test_deserialize_invalid_length() {
let bytes = vec![0u8; 100]; // Wrong length
let result = NcclBootstrap::deserialize(&bytes);
assert!(result.is_err());
}
}
......@@ -27,6 +27,150 @@ use anyhow::Result;
use async_trait::async_trait;
use std::{any::Any, sync::Arc};
#[cfg(feature = "nccl")]
use cudarc::nccl::sys::ncclComm_t;
/// Transfer execution mode for distributed workers
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum TransferMode {
/// Each rank manages its own shard independently (default)
#[default]
Sharded,
/// All ranks replicate Device data via NCCL broadcast
Replicated,
}
/// Thread-safe wrapper for NCCL communicator handle.
///
/// # Safety
/// NCCL communicators are thread-safe once created. All NCCL operations using the same
/// communicator will be serialized internally by NCCL. The raw pointer is safe to send
/// between threads as long as the communicator is not destroyed while in use.
#[cfg(feature = "nccl")]
#[derive(Clone, Copy)]
pub struct NcclCommHandle(ncclComm_t);
#[cfg(feature = "nccl")]
impl NcclCommHandle {
/// Create a new NcclCommHandle from a raw ncclComm_t.
///
/// # Safety
/// The caller must ensure that:
/// - `comm` is a valid NCCL communicator
/// - The communicator will not be destroyed while this handle exists
pub unsafe fn new(comm: ncclComm_t) -> Self {
Self(comm)
}
/// Get the raw ncclComm_t handle.
pub fn as_raw(&self) -> ncclComm_t {
self.0
}
}
// Safety: NCCL communicators are thread-safe once created
#[cfg(feature = "nccl")]
unsafe impl Send for NcclCommHandle {}
#[cfg(feature = "nccl")]
unsafe impl Sync for NcclCommHandle {}
/// Inner NCCL configuration (only available with nccl feature)
#[cfg(feature = "nccl")]
#[derive(Clone, Copy)]
struct NcclConfigInner {
comm: NcclCommHandle,
rank: i32,
world_size: i32,
}
/// Transfer mode configuration for replicated transfers.
/// Always available regardless of NCCL feature - use is_enabled() to check.
#[derive(Clone, Copy, Default)]
pub struct NcclConfig {
#[cfg(feature = "nccl")]
inner: Option<NcclConfigInner>,
#[cfg(not(feature = "nccl"))]
_phantom: (),
}
impl NcclConfig {
/// Create a disabled/empty config (sharded mode)
pub fn disabled() -> Self {
Self::default()
}
/// Create an enabled config for replicated mode (only with nccl feature)
///
/// # Preconditions
/// - `0 <= rank < world_size`
/// - `world_size > 0`
///
/// # Safety
/// The caller must ensure that:
/// - `comm` is a valid NCCL communicator
/// - The communicator will not be destroyed while this config exists
#[cfg(feature = "nccl")]
pub unsafe fn enabled(comm: ncclComm_t, rank: i32, world_size: i32) -> Self {
unsafe {
assert!(
world_size > 0 && (0..world_size).contains(&rank),
"NCCL topology invariant violated: required 0 <= rank < world_size, world_size > 0; got rank={}, world_size={}",
rank,
world_size
);
Self {
inner: Some(NcclConfigInner {
comm: NcclCommHandle::new(comm),
rank,
world_size,
}),
}
}
}
/// Returns true if NCCL is enabled and configured
pub fn is_enabled(&self) -> bool {
#[cfg(feature = "nccl")]
{
self.inner.is_some()
}
#[cfg(not(feature = "nccl"))]
{
false
}
}
/// Get rank (panics if not enabled)
pub fn rank(&self) -> i32 {
#[cfg(feature = "nccl")]
{
self.inner.as_ref().expect("NCCL not enabled").rank
}
#[cfg(not(feature = "nccl"))]
{
panic!("NCCL feature not enabled")
}
}
/// Get world size (panics if not enabled)
pub fn world_size(&self) -> i32 {
#[cfg(feature = "nccl")]
{
self.inner.as_ref().expect("NCCL not enabled").world_size
}
#[cfg(not(feature = "nccl"))]
{
panic!("NCCL feature not enabled")
}
}
/// Get the NCCL communicator handle (panics if not enabled)
#[cfg(feature = "nccl")]
pub fn comm(&self) -> NcclCommHandle {
self.inner.as_ref().expect("NCCL not enabled").comm
}
}
type LocalBlock<S, M> = Block<S, locality::Local, M>;
type LocalBlockDataList<S> = Vec<LocalBlockData<S>>;
......@@ -49,6 +193,12 @@ impl ConnectorTransferBatcher {
handler: &BlockTransferHandler,
request: BlockTransferRequest,
) -> Result<()> {
// In replicated mode, execute sequentially (all ranks must participate together)
// to ensure proper NCCL collective synchronization
if handler.transfer_mode() == TransferMode::Replicated {
return handler.execute_transfer_direct(request).await;
}
let blocks = request.blocks();
let num_blocks = blocks.len();
......@@ -92,7 +242,11 @@ pub struct BlockTransferHandler {
context: Arc<TransferContext>,
scheduler_client: Option<TransferSchedulerClient>,
batcher: ConnectorTransferBatcher,
// add worker-connector scheduler client here
/// Transfer mode: sharded (default) or replicated
transfer_mode: TransferMode,
/// NCCL config (required for replicated mode)
#[cfg(feature = "nccl")]
nccl_config: NcclConfig,
}
impl BlockTransferHandler {
......@@ -102,8 +256,14 @@ impl BlockTransferHandler {
disk_blocks: Option<Vec<LocalBlock<DiskStorage, BasicMetadata>>>,
context: Arc<TransferContext>,
scheduler_client: Option<TransferSchedulerClient>,
// add worker-connector scheduler client here
nccl_config: NcclConfig,
) -> Result<Self> {
let transfer_mode = if nccl_config.is_enabled() {
TransferMode::Replicated
} else {
TransferMode::Sharded
};
Ok(Self {
device: Self::get_local_data(device_blocks),
host: Self::get_local_data(host_blocks),
......@@ -111,9 +271,17 @@ impl BlockTransferHandler {
context,
scheduler_client,
batcher: ConnectorTransferBatcher::new(),
transfer_mode,
#[cfg(feature = "nccl")]
nccl_config,
})
}
/// Returns the transfer mode (sharded or replicated)
pub fn transfer_mode(&self) -> TransferMode {
self.transfer_mode
}
fn get_local_data<S: Storage>(
blocks: Option<Vec<LocalBlock<S, BasicMetadata>>>,
) -> Option<LocalBlockDataList<S>> {
......@@ -186,8 +354,21 @@ impl BlockTransferHandler {
/// Execute transfer directly without batching (used by the batcher)
pub async fn execute_transfer_direct(&self, request: BlockTransferRequest) -> Result<()> {
match self.transfer_mode {
TransferMode::Sharded => self.execute_transfer_spmd_sharded(request).await,
#[cfg(feature = "nccl")]
TransferMode::Replicated => self.execute_transfer_spmd_replicated(request).await,
#[cfg(not(feature = "nccl"))]
TransferMode::Replicated => {
Err(anyhow::anyhow!("Replicated mode requires NCCL feature"))
}
}
}
/// Execute transfer using sharded mode (each rank manages its own shard independently)
async fn execute_transfer_spmd_sharded(&self, request: BlockTransferRequest) -> Result<()> {
tracing::debug!(
"Performing transfer of {} blocks from {:?} to {:?}",
"Performing sharded transfer of {} blocks from {:?} to {:?}",
request.blocks().len(),
request.from_pool(),
request.to_pool()
......@@ -209,6 +390,141 @@ impl BlockTransferHandler {
notify.await?;
Ok(())
}
/// Execute transfer using replicated mode (NCCL broadcast for Device blocks)
#[cfg(feature = "nccl")]
async fn execute_transfer_spmd_replicated(&self, request: BlockTransferRequest) -> Result<()> {
assert!(
self.nccl_config.is_enabled(),
"NCCL config required for replicated mode"
);
let rank = self.nccl_config.rank();
let is_rank0 = rank == 0;
let use_bcast = request.to_pool() == &Device && request.from_pool() != &Device;
if use_bcast {
tracing::info!(
"NCCL replicated transfer: {} blocks from {:?} to {:?}, rank={}, \
rank0 will load from storage then broadcast to all GPUs",
request.blocks().len(),
request.from_pool(),
request.to_pool(),
rank
);
} else {
tracing::debug!(
"Replicated transfer: {} blocks from {:?} to {:?} (rank={}, bcast={})",
request.blocks().len(),
request.from_pool(),
request.to_pool(),
rank,
use_bcast
);
}
// Device → Device: all ranks do local transfer (no broadcast)
if request.from_pool() == &Device && request.to_pool() == &Device {
return self.execute_transfer_spmd_sharded(request).await;
}
// Non-rank0 with no broadcast needed: no-op
if !is_rank0 && !use_bcast {
return Ok(());
}
// Rank 0 does the actual copy
if is_rank0 {
let notify = match (request.from_pool(), request.to_pool()) {
(Device, Host) => {
self.begin_transfer(&self.device, &self.host, request.clone())
.await
}
(Device, Disk) => {
self.begin_transfer(&self.device, &self.disk, request.clone())
.await
}
(Host, Device) => {
self.begin_transfer(&self.host, &self.device, request.clone())
.await
}
(Host, Disk) => {
self.begin_transfer(&self.host, &self.disk, request.clone())
.await
}
(Disk, Device) => {
self.begin_transfer(&self.disk, &self.device, request.clone())
.await
}
_ => {
return Err(anyhow::anyhow!("Invalid transfer type."));
}
}?;
notify.await?;
}
// Broadcast Device blocks if needed (all ranks participate)
if use_bcast {
self.broadcast_device_blocks(&request).await?;
}
Ok(())
}
/// Broadcast Device blocks to all ranks using NCCL
#[cfg(feature = "nccl")]
async fn broadcast_device_blocks(&self, request: &BlockTransferRequest) -> Result<()> {
use crate::block_manager::block::transfer::{NcclGroup, bcast_block};
let device_blocks = self
.device
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Device blocks required for broadcast"))?;
// Get raw CUstream from the CudaStream wrapper
let stream = self.context.stream().cu_stream();
let comm = self.nccl_config.comm();
// Get destination block indices (the Device blocks to broadcast)
let dst_indices: Vec<usize> = request.blocks().iter().map(|(_, to)| *to).collect();
let rank = self.nccl_config.rank();
let world_size = self.nccl_config.world_size();
tracing::info!(
"NCCL broadcast starting: rank={}/{}, num_blocks={}, block_indices={:?}",
rank,
world_size,
dst_indices.len(),
dst_indices
);
// Create NCCL group and broadcast all blocks
let group = unsafe { NcclGroup::new()? };
for &block_idx in &dst_indices {
let block = &device_blocks[block_idx];
unsafe {
bcast_block(block, 0, comm.as_raw(), stream)?;
}
}
group.end()?; // Submit the group so we can observe ncclGroupEnd errors
drop(group);
// Synchronize: wait for all NCCL operations to complete on the stream
let (tx, rx) = tokio::sync::oneshot::channel();
self.context.cuda_event(tx)?;
rx.await
.map_err(|_| anyhow::anyhow!("CUDA event channel closed"))?;
tracing::info!(
"NCCL broadcast completed: rank={}/{}, num_blocks={}",
rank,
world_size,
dst_indices.len()
);
Ok(())
}
}
#[async_trait]
......
......@@ -112,7 +112,25 @@ async fn perform_allocation_and_build_handler(
device_id: usize,
scheduler_client: Option<TransferSchedulerClient>,
) -> anyhow::Result<BlockTransferHandler> {
let agent = build_agent(worker_id, leader_meta.num_disk_blocks > 0)?;
// Determine if this rank should allocate G2/G3 (host/disk)
// - Sharded mode (rank=None): all ranks allocate
// - Replicated mode (rank=Some(r)): only rank 0 allocates
let should_allocate_offload = match worker_config.rank {
None => true, // Sharded mode: all ranks allocate
Some(0) => true, // Replicated mode rank 0: allocate
Some(_) => false, // Replicated mode non-rank0: skip
};
if !should_allocate_offload {
tracing::info!(
"Rank {} skipping host/disk allocation (replicated mode)",
worker_config.rank.unwrap_or(-1)
);
}
// Only create NIXL agent if we need disk blocks AND we should allocate
let need_disk = should_allocate_offload && leader_meta.num_disk_blocks > 0;
let agent = build_agent(worker_id, need_disk)?;
let pool_config = PoolConfig {
enable_pool: true,
max_concurrent_transfers: max_concurrent_transfers(),
......@@ -139,17 +157,17 @@ async fn perform_allocation_and_build_handler(
})?,
);
// device
// device - always allocated on all ranks
let device_blocks = Some(KvbmWorker::make_layout::<_, BasicMetadata>(
device_layout,
transfer_context.nixl_agent().as_ref(),
0,
worker_id,
)?);
// host
let host_blocks = if leader_meta.num_host_blocks > 0 {
let host_allocator = Arc::new(PinnedAllocator::new(device_id)?);
// host (G2) - only allocated if should_allocate_offload
let host_blocks = if should_allocate_offload && leader_meta.num_host_blocks > 0 {
let host_allocator = Arc::new(PinnedAllocator::default());
let host_layout = layout_builder
.num_blocks(leader_meta.num_host_blocks)
.build()?
......@@ -163,8 +181,8 @@ async fn perform_allocation_and_build_handler(
} else {
None
};
// disk
let disk_blocks = if leader_meta.num_disk_blocks > 0 {
// disk (G3) - only allocated if should_allocate_offload
let disk_blocks = if should_allocate_offload && leader_meta.num_disk_blocks > 0 {
let disk_allocator = Arc::new(DiskAllocator::from_env()?);
let disk_layout = layout_builder
.num_blocks(leader_meta.num_disk_blocks)
......@@ -186,6 +204,7 @@ async fn perform_allocation_and_build_handler(
disk_blocks,
transfer_context,
scheduler_client,
worker_config.nccl_config,
)?;
Ok(handler)
}
......@@ -411,6 +430,14 @@ pub struct KvbmWorkerConfig {
#[builder(default = "String::from(\"tcp://127.0.0.1:56002\")")]
leader_ack_url: String,
/// Rank for replicated mode (None = sharded mode)
#[builder(default = "None")]
rank: Option<i32>,
/// NCCL configuration for replicated mode
#[builder(default = "transfer::NcclConfig::disabled()")]
nccl_config: transfer::NcclConfig,
}
impl KvbmWorkerConfig {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment