You need to sign in or sign up before continuing.
Unverified commit 2263defc, authored by Alec and committed by GitHub
Browse files

fix(kvbm): set CUDA device per rank before ncclCommInitRank (#8147)


Signed-off-by: Alec Flowers <aflowers@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 37bef607
...@@ -88,13 +88,20 @@ def _create_kvbm_nccl_comm(rank: int, world_size: int): ...@@ -88,13 +88,20 @@ def _create_kvbm_nccl_comm(rank: int, world_size: int):
logger.info(f"KVBM: Rank {rank} bootstrap world_size={bootstrap.world_size()}") logger.info(f"KVBM: Rank {rank} bootstrap world_size={bootstrap.world_size()}")
# Trust the framework (TRT-LLM / MPI launcher) to have already # TRT-LLM MPI launch exposes all GPUs to each rank but manages device
# set the correct CUDA device for this rank, either via # assignment internally. torch.cuda.current_device() defaults to 0 for
# CUDA_VISIBLE_DEVICES or its own initialization. # all ranks, which causes ncclCommInitRank to fail (all ranks on same
current_device = torch.cuda.current_device() # device). Explicitly set the device to match the MPI rank.
device_count = torch.cuda.device_count()
if device_count <= 0:
raise RuntimeError(
"KVBM NCCL MLA mode requires at least one visible CUDA device."
)
device_id = rank % device_count
torch.cuda.set_device(device_id)
logger.info( logger.info(
f"KVBM: Rank {rank} on CUDA device {current_device} " f"KVBM: Rank {rank} set to CUDA device {device_id} "
f"(device_count={torch.cuda.device_count()})" f"(device_count={device_count})"
) )
logger.info(f"KVBM: Rank {rank} waiting at MPI barrier " "before ncclCommInitRank") logger.info(f"KVBM: Rank {rank} waiting at MPI barrier " "before ncclCommInitRank")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment