"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "3cbc6a5c01120b59cf38378d5f18836bcd90f6c9"
Commit c962f4ce authored by gushiqiao

Fix: apply remove_key filtering in the single-GPU branch of load_weights

parent 2e3472a7
@@ -327,8 +327,12 @@ def find_gguf_model_path(config, ckpt_config_key=None, subdir=None):
 def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
     if not dist.is_initialized():
         # Single GPU mode
+        cpu_weight_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+        for key in list(cpu_weight_dict.keys()):
+            if remove_key and remove_key in key:
+                cpu_weight_dict.pop(key)
         logger.info(f"Loading weights from {checkpoint_path}")
-        return torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+        return cpu_weight_dict
     # Multi-GPU mode
     is_weight_loader = False
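For readers skimming the diff: the single-GPU branch now applies the same remove_key filtering as the multi-GPU path instead of returning the raw state dict. A minimal self-contained sketch of the new behavior (the checkpoint path and the "lora" filter value are hypothetical, not from this repo):

import torch

# Save a tiny checkpoint so the example runs standalone.
state = {"layer.weight": torch.zeros(4, 4), "lora.A.weight": torch.zeros(4, 4)}
torch.save(state, "/tmp/demo_ckpt.pt")

# New single-GPU behavior: load on CPU, then drop any entry whose
# name contains remove_key before returning the dict.
remove_key = "lora"  # hypothetical filter value
cpu_weight_dict = torch.load("/tmp/demo_ckpt.pt", map_location="cpu", weights_only=True)
for key in list(cpu_weight_dict.keys()):
    if remove_key and remove_key in key:
        cpu_weight_dict.pop(key)
print(sorted(cpu_weight_dict))  # ['layer.weight']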
@@ -337,14 +341,13 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
         is_weight_loader = True
     cpu_weight_dict = {}
-    if is_weight_loader:  # rank 0 loads the full weight dict on CPU
+    if is_weight_loader:
         logger.info(f"Loading weights from {checkpoint_path}")
         cpu_weight_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
         for key in list(cpu_weight_dict.keys()):
             if remove_key and remove_key in key:
                 cpu_weight_dict.pop(key)
-    # Sync the structure of the dict
     meta_dict = {}
     if is_weight_loader:
         for key, tensor in cpu_weight_dict.items():
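The loop body that builds meta_dict is collapsed between these hunks, but the later reads of meta["shape"] and meta["dtype"] imply entries of that form. A sketch of the metadata handshake under that assumption (requires an initialized process group; the sample tensor is illustrative):

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already run.
is_weight_loader = dist.get_rank() == 0
cpu_weight_dict = {"w": torch.zeros(2, 3)} if is_weight_loader else {}

# Rank 0 describes each tensor (shape and dtype only, no data), then
# broadcast_object_list ships the descriptions to every rank.
meta_dict = {}
if is_weight_loader:
    for key, tensor in cpu_weight_dict.items():
        meta_dict[key] = {"shape": tuple(tensor.shape), "dtype": tensor.dtype}
obj_list = [meta_dict] if is_weight_loader else [None]
dist.broadcast_object_list(obj_list, src=0)
synced_meta_dict = obj_list[0]  # identical on every rank
# Non-loader ranks can now preallocate torch.empty(meta["shape"], dtype=meta["dtype"]).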
@@ -352,31 +355,25 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
     obj_list = [meta_dict] if is_weight_loader else [None]
-    # Use rank 0's global rank as the broadcast source
     src_global_rank = 0
     dist.broadcast_object_list(obj_list, src=src_global_rank)
     synced_meta_dict = obj_list[0]
-    # Pick the target device based on the offload config
     if cpu_offload:
         # Multi-GPU + offload: weights on CPU
         target_device = "cpu"
         distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
-        # CPU distribution uses a plain barrier
         dist.barrier()
     else:
         # Multi-GPU + non-offload: weights on GPU
         target_device = torch.device(f"cuda:{current_rank}")
         distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
-        # GPU distribution uses a CUDA barrier
         dist.barrier(device_ids=[torch.cuda.current_device()])
-    # Broadcast the weights
     for key in sorted(synced_meta_dict.keys()):
         tensor_to_broadcast = distributed_weight_dict[key]
         if is_weight_loader:
             tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
-        # Broadcast to all ranks
         dist.broadcast(tensor_to_broadcast, src=src_global_rank)
     if is_weight_loader:
......
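Two details in the broadcast section are easy to miss: iterating sorted(synced_meta_dict.keys()) gives every rank the same collective order, so the dist.broadcast calls match up, and the copy into the preallocated buffer happens only on the loader rank before each broadcast. An end-to-end usage sketch, assuming a torchrun launch and the hypothetical checkpoint path and filter value from above:

# Launch with: torchrun --nproc_per_node=2 load_demo.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

# load_weights is the function patched in this commit.
weights = load_weights("/tmp/demo_ckpt.pt", cpu_offload=False, remove_key="lora")
# Every rank now holds an identical dict of GPU tensors.
print(dist.get_rank(), {k: v.device for k, v in weights.items()})

dist.destroy_process_group()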