"vscode:/vscode.git/clone" did not exist on "371b1251dd7db7581377e40ea4e4626c3f83ef7a"
Unverified Commit d9bede03 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Fix `fastsafetensors` TP all procs using all GPUs (#34070)


Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent 22b64948
......@@ -801,8 +801,8 @@ def runai_safetensors_weights_iterator(
yield from tensor_iter
def _init_loader(
pg: torch.distributed.ProcessGroup,
def _init_fastsafetensors_loader(
pg: "torch.distributed.ProcessGroup",
device: torch.device,
f_list: list[str],
*,
......@@ -825,13 +825,16 @@ def fastsafetensors_weights_iterator(
else:
pg = SingleGroup()
device = torch.device(f"cuda:{pg.rank()}")
device = torch.device(f"cuda:{current_platform.current_device()}")
weight_files_sub_lists = [
hf_weights_files[i : i + pg.size()]
for i in range(0, len(hf_weights_files), pg.size())
]
nogds = False
# Use nogds=True for TP > 1 to avoid cuFileDriverOpen() which
# initializes the GDS DMA subsystem for all visible GPUs, creating
# unwanted CUDA contexts on every device.
nogds = pg.size() > 1
for f_list in tqdm(
weight_files_sub_lists,
......@@ -839,7 +842,7 @@ def fastsafetensors_weights_iterator(
disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
loader = _init_loader(pg, device, f_list, nogds=nogds)
loader = _init_fastsafetensors_loader(pg, device, f_list, nogds=nogds)
try:
try:
fb = loader.copy_files_to_device()
......@@ -853,7 +856,7 @@ def fastsafetensors_weights_iterator(
"GDS not enabled, setting `nogds=True`.\n"
"For more information, see: https://github.com/foundation-model-stack/fastsafetensors?tab=readme-ov-file#basic-api-usages"
)
loader = _init_loader(pg, device, f_list, nogds=nogds)
loader = _init_fastsafetensors_loader(pg, device, f_list, nogds=nogds)
fb = loader.copy_files_to_device()
try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment