Commit 3b896f9c authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

[Fix] Fix distribute load model bug (#315)

parent f085ede3
from abc import ABCMeta, abstractmethod
import torch
import torch.distributed as dist
from loguru import logger
from lightx2v.utils.envs import *
......@@ -101,6 +102,7 @@ class MMWeight(MMWeightTemplate):
self.bias.copy_(weight_dict[self.bias_name])
else:
self.bias = None
del weight_dict[self.weight_name]
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
......@@ -203,6 +205,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
weight_scale_dtype = torch.float
self.weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
self.weight_scale.copy_(weight_dict[self.weight_scale_name])
if dist.is_initialized():
del weight_dict[self.weight_name]
else:
raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
......
......@@ -122,14 +122,21 @@ class WanModel(CompiledMethodsMixin):
# Single GPU mode
return True
elif dist.is_initialized():
# Multi-GPU mode, only rank 0 loads
if dist.get_rank() == 0:
logger.info(f"Loading weights from {self.model_path}")
if self.config.get("load_from_rank0", False):
# Multi-GPU mode, only rank 0 loads
if dist.get_rank() == 0:
logger.info(f"Loading weights from {self.model_path}")
return True
else:
return True
return False
def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
    """Load one safetensors file into a {tensor_name: tensor} dict.

    In multi-GPU runs each rank must load onto its own CUDA device, so the
    target device is derived from the process rank instead of ``self.device``.

    Args:
        file_path: Path to a ``.safetensors`` checkpoint shard.
        unified_dtype: When truthy, cast every tensor to ``GET_DTYPE()``.
        sensitive_layer: Iterable of substrings naming layers that must keep
            ``GET_SENSITIVE_DTYPE()`` instead of the unified dtype.

    Returns:
        dict mapping tensor name -> tensor loaded on the chosen device.
    """
    # NOTE(review): removed a stale duplicate `with safe_open(...)` line that
    # targeted `self.device` unconditionally (leftover from the pre-fix code).
    # NOTE(review): dist.get_rank() is the *global* rank; this assumes one
    # process per visible GPU so it is also a valid CUDA index — confirm for
    # multi-node setups where the local rank may differ.
    if self.device.type == "cuda" and dist.is_initialized():
        device = torch.device("cuda:{}".format(dist.get_rank()))
    else:
        device = self.device
    with safe_open(file_path, framework="pt", device=str(device)) as f:
        return {
            key: (
                f.get_tensor(key).to(GET_DTYPE())
                if unified_dtype or all(s not in key for s in sensitive_layer)
                else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())
            )
            for key in f.keys()
        }
def _load_ckpt(self, unified_dtype, sensitive_layer):
......@@ -147,8 +154,6 @@ class WanModel(CompiledMethodsMixin):
def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
ckpt_path = self.dit_quantized_ckpt
logger.info(f"Loading quant dit model from {ckpt_path}")
index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
if not index_files:
raise FileNotFoundError(f"No *.index.json found in {ckpt_path}")
......@@ -236,8 +241,8 @@ class WanModel(CompiledMethodsMixin):
else:
weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
if self.config.get("device_mesh") is not None:
weight_dict = self._load_weights_distribute(weight_dict, is_weight_loader)
if self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False):
weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader)
if hasattr(self, "adapter_weights_dict"):
weight_dict.update(self.adapter_weights_dict)
......@@ -258,7 +263,8 @@ class WanModel(CompiledMethodsMixin):
torch.cuda.empty_cache()
gc.collect()
def _load_weights_distribute(self, weight_dict, is_weight_loader):
def _load_weights_from_rank0(self, weight_dict, is_weight_loader):
logger.info("Loading distributed weights")
global_src_rank = 0
target_device = "cpu" if self.cpu_offload else "cuda"
......@@ -313,6 +319,7 @@ class WanModel(CompiledMethodsMixin):
tensor.copy_(tensor, non_blocking=False)
logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
return distributed_weight_dict
def _init_infer(self):
......
......@@ -10,6 +10,7 @@ from requests.exceptions import RequestException
from lightx2v.utils.envs import *
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.memory_profiler import peak_memory_decorator
from lightx2v.utils.profiler import *
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
......@@ -112,6 +113,7 @@ class DefaultRunner(BaseRunner):
def set_progress_callback(self, callback):
    """Register a callable used to report inference progress; stored on the
    instance and presumably invoked during `run_segment` — verify caller."""
    self.progress_callback = callback
@peak_memory_decorator
def run_segment(self, total_steps=None):
if total_steps is None:
total_steps = self.model.scheduler.infer_steps
......
......@@ -363,7 +363,7 @@ def load_pt_safetensors(in_path, remove_key):
return state_dict
def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
def load_weights(checkpoint_path, cpu_offload=False, remove_key=None, load_from_rank0=False):
if not dist.is_initialized():
# Single GPU mode
logger.info(f"Loading weights from {checkpoint_path}")
......@@ -371,10 +371,10 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
return cpu_weight_dict
# Multi-GPU mode
is_weight_loader = False
is_weight_loader = True
current_rank = dist.get_rank()
if current_rank == 0:
is_weight_loader = True
if load_from_rank0 and current_rank != 0:
is_weight_loader = False
cpu_weight_dict = {}
if is_weight_loader:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment