Commit 3b896f9c authored by gushiqiao, committed by GitHub

[Fix] Fix distribute load model bug (#315)

parent f085ede3
 from abc import ABCMeta, abstractmethod
 import torch
+import torch.distributed as dist
 from loguru import logger
 from lightx2v.utils.envs import *
@@ -101,6 +102,7 @@ class MMWeight(MMWeightTemplate):
                 self.bias.copy_(weight_dict[self.bias_name])
             else:
                 self.bias = None
+            del weight_dict[self.weight_name]
         else:
             raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
@@ -203,6 +205,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 weight_scale_dtype = torch.float
             self.weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
             self.weight_scale.copy_(weight_dict[self.weight_scale_name])
+            if dist.is_initialized():
+                del weight_dict[self.weight_name]
         else:
             raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
......
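The two hunks above release host memory as weights are staged: once a tensor has been copied into its pinned destination, the source entry is dropped from `weight_dict` so its storage can be freed. A minimal sketch of that pattern; `stage_weight` and `src_dict` are illustrative names, not lightx2v API:

```python
import torch

def stage_weight(src_dict: dict, name: str) -> torch.Tensor:
    """Copy a tensor into pinned host memory, then drop the source entry.

    Pinned memory enables fast async host-to-device transfers later;
    deleting the dict entry (mirroring the `del weight_dict[...]` added
    above) releases the original tensor once no other reference remains.
    Note: pin_memory=True requires a CUDA-capable build of PyTorch.
    """
    src = src_dict[name]
    staged = torch.empty(src.shape, dtype=src.dtype, pin_memory=True)
    staged.copy_(src)
    del src_dict[name]
    return staged
```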
@@ -122,14 +122,21 @@ class WanModel(CompiledMethodsMixin):
             # Single GPU mode
             return True
         elif dist.is_initialized():
-            # Multi-GPU mode, only rank 0 loads
-            if dist.get_rank() == 0:
-                logger.info(f"Loading weights from {self.model_path}")
+            if self.config.get("load_from_rank0", False):
+                # Multi-GPU mode, only rank 0 loads
+                if dist.get_rank() == 0:
+                    logger.info(f"Loading weights from {self.model_path}")
+                    return True
+            else:
                 return True
         return False

     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
-        with safe_open(file_path, framework="pt", device=str(self.device)) as f:
+        if self.device.type == "cuda" and dist.is_initialized():
+            device = torch.device("cuda:{}".format(dist.get_rank()))
+        else:
+            device = self.device
+        with safe_open(file_path, framework="pt", device=str(device)) as f:
             return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())) for key in f.keys()}

     def _load_ckpt(self, unified_dtype, sensitive_layer):
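The `_load_safetensor_to_dict` change makes each process load onto its own GPU instead of a shared `self.device`. A standalone sketch of that device resolution, assuming one process per GPU with the global rank doubling as the CUDA ordinal (as the diff above does; multi-node setups would typically use the local rank instead):

```python
import torch
import torch.distributed as dist
from safetensors import safe_open

def load_on_rank_device(file_path: str, fallback: torch.device) -> dict:
    # Under torch.distributed, map this rank to its own CUDA device;
    # otherwise keep the caller-provided device (e.g. "cpu").
    if fallback.type == "cuda" and dist.is_initialized():
        device = torch.device(f"cuda:{dist.get_rank()}")
    else:
        device = fallback
    with safe_open(file_path, framework="pt", device=str(device)) as f:
        return {key: f.get_tensor(key) for key in f.keys()}
```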
@@ -147,8 +154,6 @@ class WanModel(CompiledMethodsMixin):
     def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
         ckpt_path = self.dit_quantized_ckpt
-        logger.info(f"Loading quant dit model from {ckpt_path}")
         index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
         if not index_files:
             raise FileNotFoundError(f"No *.index.json found in {ckpt_path}")
@@ -236,8 +241,8 @@ class WanModel(CompiledMethodsMixin):
         else:
             weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
-        if self.config.get("device_mesh") is not None:
-            weight_dict = self._load_weights_distribute(weight_dict, is_weight_loader)
+        if self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False):
+            weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader)
         if hasattr(self, "adapter_weights_dict"):
             weight_dict.update(self.adapter_weights_dict)
@@ -258,7 +263,8 @@ class WanModel(CompiledMethodsMixin):
         torch.cuda.empty_cache()
         gc.collect()

-    def _load_weights_distribute(self, weight_dict, is_weight_loader):
+    def _load_weights_from_rank0(self, weight_dict, is_weight_loader):
+        logger.info("Loading distributed weights")
         global_src_rank = 0
         target_device = "cpu" if self.cpu_offload else "cuda"
@@ -313,6 +319,7 @@ class WanModel(CompiledMethodsMixin):
             tensor.copy_(tensor, non_blocking=False)
         logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
         return distributed_weight_dict
+
     def _init_infer(self):
......
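`_load_weights_from_rank0` (its body is elided in this view) distributes rank 0's tensors to the other ranks. A hedged sketch of how such a load-on-rank-0-then-broadcast flow can look; the function name, signature, and metadata exchange below are assumptions, not the actual lightx2v implementation:

```python
import torch
import torch.distributed as dist

def broadcast_weights_from_rank0(weight_dict, src_rank=0, target_device="cuda"):
    """Source rank materializes each tensor; every rank receives a copy."""
    rank = dist.get_rank()
    # Share tensor names/shapes/dtypes so non-source ranks can allocate
    # receive buffers of the right size.
    meta = [{k: (tuple(v.shape), v.dtype) for k, v in weight_dict.items()}] if rank == src_rank else [None]
    dist.broadcast_object_list(meta, src=src_rank)
    received = {}
    for name, (shape, dtype) in meta[0].items():
        if rank == src_rank:
            tensor = weight_dict[name].to(target_device)
        else:
            tensor = torch.empty(shape, dtype=dtype, device=target_device)
        # NCCL backends require CUDA tensors; use gloo for target_device="cpu".
        dist.broadcast(tensor, src=src_rank)
        received[name] = tensor
    return received
```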
@@ -10,6 +10,7 @@ from requests.exceptions import RequestException
 from lightx2v.utils.envs import *
 from lightx2v.utils.generate_task_id import generate_task_id
+from lightx2v.utils.memory_profiler import peak_memory_decorator
 from lightx2v.utils.profiler import *
 from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
@@ -112,6 +113,7 @@ class DefaultRunner(BaseRunner):
     def set_progress_callback(self, callback):
         self.progress_callback = callback

+    @peak_memory_decorator
     def run_segment(self, total_steps=None):
         if total_steps is None:
             total_steps = self.model.scheduler.infer_steps
......
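`peak_memory_decorator` wraps `run_segment` so each segment's peak GPU usage gets reported. A minimal sketch of what such a decorator can look like; lightx2v's real implementation in `lightx2v.utils.memory_profiler` may differ:

```python
import functools
import torch
from loguru import logger

def peak_memory_decorator(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Reset the allocator's high-water mark before the call...
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        result = fn(*args, **kwargs)
        # ...then read it back and log the peak in GB.
        if torch.cuda.is_available():
            peak_gb = torch.cuda.max_memory_allocated() / 1024**3
            logger.info(f"{fn.__name__} peak CUDA memory: {peak_gb:.2f} GB")
        return result
    return wrapper
```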
@@ -363,7 +363,7 @@ def load_pt_safetensors(in_path, remove_key):
     return state_dict

-def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
+def load_weights(checkpoint_path, cpu_offload=False, remove_key=None, load_from_rank0=False):
     if not dist.is_initialized():
         # Single GPU mode
         logger.info(f"Loading weights from {checkpoint_path}")
@@ -371,10 +371,10 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None, load_from_rank0=False):
         return cpu_weight_dict

     # Multi-GPU mode
-    is_weight_loader = False
+    is_weight_loader = True
     current_rank = dist.get_rank()
-    if current_rank == 0:
-        is_weight_loader = True
+    if load_from_rank0 and current_rank != 0:
+        is_weight_loader = False
     cpu_weight_dict = {}
     if is_weight_loader:
......
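With this change every rank loads the checkpoint by default and only opts out when `load_from_rank0=True`, inverting the old rank-0-only behavior that left non-zero ranks without weights unless the broadcast path ran. Illustrative call sites (the path is a placeholder):

```python
# Default: each rank reads the checkpoint from disk itself.
weights = load_weights("ckpt_dir", cpu_offload=False)

# Opt-in: only rank 0 reads from disk; the other ranks expect the
# weights to be distributed to them afterwards (e.g. via the
# rank-0 broadcast path sketched earlier).
weights = load_weights("ckpt_dir", cpu_offload=False, load_from_rank0=True)
```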