Unverified commit d9bede03, authored by Nick Hill, committed by GitHub
Browse files

[BugFix] Fix `fastsafetensors` TP all procs using all GPUs (#34070)


Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
parent 22b64948
...@@ -801,8 +801,8 @@ def runai_safetensors_weights_iterator( ...@@ -801,8 +801,8 @@ def runai_safetensors_weights_iterator(
yield from tensor_iter yield from tensor_iter
def _init_loader( def _init_fastsafetensors_loader(
pg: torch.distributed.ProcessGroup, pg: "torch.distributed.ProcessGroup",
device: torch.device, device: torch.device,
f_list: list[str], f_list: list[str],
*, *,
...@@ -825,13 +825,16 @@ def fastsafetensors_weights_iterator( ...@@ -825,13 +825,16 @@ def fastsafetensors_weights_iterator(
else: else:
pg = SingleGroup() pg = SingleGroup()
device = torch.device(f"cuda:{pg.rank()}") device = torch.device(f"cuda:{current_platform.current_device()}")
weight_files_sub_lists = [ weight_files_sub_lists = [
hf_weights_files[i : i + pg.size()] hf_weights_files[i : i + pg.size()]
for i in range(0, len(hf_weights_files), pg.size()) for i in range(0, len(hf_weights_files), pg.size())
] ]
nogds = False # Use nogds=True for TP > 1 to avoid cuFileDriverOpen() which
# initializes the GDS DMA subsystem for all visible GPUs, creating
# unwanted CUDA contexts on every device.
nogds = pg.size() > 1
for f_list in tqdm( for f_list in tqdm(
weight_files_sub_lists, weight_files_sub_lists,
...@@ -839,7 +842,7 @@ def fastsafetensors_weights_iterator( ...@@ -839,7 +842,7 @@ def fastsafetensors_weights_iterator(
disable=not enable_tqdm(use_tqdm_on_load), disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT, bar_format=_BAR_FORMAT,
): ):
loader = _init_loader(pg, device, f_list, nogds=nogds) loader = _init_fastsafetensors_loader(pg, device, f_list, nogds=nogds)
try: try:
try: try:
fb = loader.copy_files_to_device() fb = loader.copy_files_to_device()
...@@ -853,7 +856,7 @@ def fastsafetensors_weights_iterator( ...@@ -853,7 +856,7 @@ def fastsafetensors_weights_iterator(
"GDS not enabled, setting `nogds=True`.\n" "GDS not enabled, setting `nogds=True`.\n"
"For more information, see: https://github.com/foundation-model-stack/fastsafetensors?tab=readme-ov-file#basic-api-usages" "For more information, see: https://github.com/foundation-model-stack/fastsafetensors?tab=readme-ov-file#basic-api-usages"
) )
loader = _init_loader(pg, device, f_list, nogds=nogds) loader = _init_fastsafetensors_loader(pg, device, f_list, nogds=nogds)
fb = loader.copy_files_to_device() fb = loader.copy_files_to_device()
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment