"app/vscode:/vscode.git/clone" did not exist on "bddfa2100f9b708c76c99167a180a05da5140e95"
Commit 2e3472a7 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

Update utils.py

parent d502fab6
...@@ -375,14 +375,7 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None): ...@@ -375,14 +375,7 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
for key in sorted(synced_meta_dict.keys()): for key in sorted(synced_meta_dict.keys()):
tensor_to_broadcast = distributed_weight_dict[key] tensor_to_broadcast = distributed_weight_dict[key]
if is_weight_loader: if is_weight_loader:
# rank0将CPU权重拷贝到目标设备,准备广播 tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
if cpu_offload:
# CPU模式:直接复制
tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
else:
# GPU模式:先复制到当前GPU,再广播
tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
# 广播到所有ranks # 广播到所有ranks
dist.broadcast(tensor_to_broadcast, src=src_global_rank) dist.broadcast(tensor_to_broadcast, src=src_global_rank)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment