Commit c2f2d263 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

[Fix] Fix bugs (#318)

parent ed1a937b
......@@ -364,17 +364,17 @@ def load_pt_safetensors(in_path, remove_key):
def load_weights(checkpoint_path, cpu_offload=False, remove_key=None, load_from_rank0=False):
if not dist.is_initialized():
if not dist.is_initialized() or not load_from_rank0:
# Single GPU mode
logger.info(f"Loading weights from {checkpoint_path}")
cpu_weight_dict = load_pt_safetensors(checkpoint_path, remove_key)
return cpu_weight_dict
# Multi-GPU mode
is_weight_loader = True
current_rank = dist.get_rank()
if load_from_rank0 and current_rank != 0:
is_weight_loader = False
current_rank = dist.get_rank()
if current_rank == 0:
is_weight_loader = True
cpu_weight_dict = {}
if is_weight_loader:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment