Commit bd21f14f authored by gushiqiao, committed by GitHub

[Fix] Fix OOM bug (#307)

parent e120838b
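The root cause of the OOM was in the weight-loading path: every matmul weight was kept once for compute and once more as a same-sized pinned host buffer (pinned_weight / pinned_weight_scale / pinned_bias) that to_cpu() round-tripped through, roughly doubling host memory. The fix allocates a single pinned host buffer per tensor, copies the checkpoint tensor into it, and then simply moves that one tensor between CPU and GPU. A minimal sketch of the pattern, simplified and with hypothetical helper names:

import torch

def load_pinned(src, device):
    # One pinned host allocation per tensor; no extra shadow copy is kept around.
    buf = torch.empty(src.shape, pin_memory=True, dtype=src.dtype).to(device)
    return buf.copy_(src)

def offload_to_cpu(weight, non_blocking=False):
    # Offload becomes a plain host transfer instead of a copy into a pinned mirror.
    return weight.to("cpu", non_blocking=non_blocking)

def onload_to_cuda(weight, non_blocking=False):
    return weight.cuda(non_blocking=non_blocking)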
......@@ -13,6 +13,8 @@
"enable_cfg": false,
"cpu_offload": true,
"offload_granularity": "model",
"t5_cpu_offload": false,
"vae_cpu_offload": false,
"use_image_encoder": false,
"boundary_step_index": 2,
"denoising_step_list": [1000, 750, 500, 250],
......
......@@ -13,5 +13,7 @@
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "model",
"t5_cpu_offload": false,
"vae_cpu_offload": false,
"boundary": 0.875
}
......@@ -13,6 +13,8 @@
"enable_cfg": false,
"cpu_offload": true,
"offload_granularity": "model",
"t5_cpu_offload": false,
"vae_cpu_offload": false,
"boundary_step_index": 2,
"denoising_step_list": [1000, 750, 500, 250],
"lora_configs": [
......
......@@ -15,6 +15,8 @@
"enable_cfg": true,
"cpu_offload": false,
"offload_granularity": "model",
"t5_cpu_offload": false,
"vae_cpu_offload": false,
"fps": 24,
"use_image_encoder": false
}
......@@ -17,5 +17,7 @@
"use_image_encoder": false,
"cpu_offload": true,
"offload_granularity": "model",
"t5_cpu_offload": false,
"vae_cpu_offload": false,
"vae_offload_cache": true
}
......@@ -15,5 +15,7 @@
"enable_cfg": true,
"cpu_offload": false,
"offload_granularity": "model",
"t5_cpu_offload": false,
"vae_cpu_offload": false,
"fps": 24
}
......@@ -16,5 +16,7 @@
"fps": 24,
"cpu_offload": true,
"offload_granularity": "model",
"t5_cpu_offload": false,
"vae_cpu_offload": false,
"vae_offload_cache": true
}
......@@ -70,18 +70,11 @@ class MMWeightTemplate(metaclass=ABCMeta):
self.bias = self.bias.cuda(non_blocking=non_blocking)
def to_cpu(self, non_blocking=False):
if hasattr(self, "pinned_weight"):
self.weight = self.pinned_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
if hasattr(self, "weight_scale_name"):
self.weight_scale = self.pinned_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
if self.bias is not None:
self.bias = self.pinned_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
else:
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
if hasattr(self, "weight_scale"):
self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
if hasattr(self, "bias") and self.bias is not None:
self.bias = self.bias.to("cpu", non_blocking=non_blocking)
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
if hasattr(self, "weight_scale"):
self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
if hasattr(self, "bias") and self.bias is not None:
self.bias = self.bias.to("cpu", non_blocking=non_blocking)
@MM_WEIGHT_REGISTER("Default")
......@@ -90,10 +83,20 @@ class MMWeight(MMWeightTemplate):
super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
def load(self, weight_dict):
self.weight = weight_dict[self.weight_name].t()
self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) if self.bias is not None else None
device = weight_dict[self.weight_name].device
weight_shape = weight_dict[self.weight_name].t().shape
weight_dtype = weight_dict[self.weight_name].dtype
self.weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype).to(device)
self.weight = self.weight.copy_(weight_dict[self.weight_name].t())
if self.bias_name is not None:
bias_shape = weight_dict[self.bias_name].shape
bias_dtype = weight_dict[self.bias_name].dtype
self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype).to(device)
self.bias = self.bias.copy_(weight_dict[self.bias_name])
else:
self.bias = None
def _calculate_size(self):
if self.bias is not None:
......@@ -166,7 +169,6 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.load_func(weight_dict)
if self.weight_need_transpose:
self.weight = self.weight.t()
self.pinned_weight = self.pinned_weight.t()
def clear(self):
attrs = ["weight", "weight_scale", "bias", "pinned_weight", "pinned_weight_scale", "pinned_bias"]
......@@ -182,11 +184,24 @@ class MMWeightQuantTemplate(MMWeightTemplate):
return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size()
def load_quantized(self, weight_dict):
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name].float()
device = weight_dict[self.weight_name].device
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype).to(device)
self.weight = self.weight.copy_(weight_dict[self.weight_name])
weight_scale_shape = weight_dict[self.weight_scale_name].shape
weight_scale_dtype = torch.float
self.weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype).to(device)
self.weight_scale = self.weight_scale.copy_(weight_dict[self.weight_scale_name])
self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
if self.bias_name is not None:
bias_shape = weight_dict[self.bias_name].shape
bias_dtype = weight_dict[self.bias_name].dtype
self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype).to(device)
self.bias = self.bias.copy_(weight_dict[self.bias_name])
else:
self.bias = None
def load_fp8_perchannel_sym(self, weight_dict):
if self.config.get("weight_auto_quant", False):
......@@ -195,14 +210,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
self.weight = self.weight.to(torch.float8_e4m3fn)
self.weight_scale = self.weight_scale.to(torch.float32)
self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
else:
self.load_quantized(weight_dict)
if self.bias_name is not None:
self.bias = weight_dict[self.bias_name]
self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
device = weight_dict[self.bias_name].device
bias_shape = weight_dict[self.bias_name].shape
bias_dtype = weight_dict[self.bias_name].dtype
self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype).to(device)
self.bias = self.bias.copy_(weight_dict[self.bias_name])
else:
self.bias = None
......@@ -213,14 +229,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
self.weight = self.weight.to(torch.int8)
self.weight_scale = self.weight_scale.to(torch.float32)
self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
else:
self.load_quantized(weight_dict)
if self.bias_name is not None:
self.bias = weight_dict[self.bias_name]
self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
device = weight_dict[self.bias_name].device
bias_shape = weight_dict[self.bias_name].shape
bias_dtype = weight_dict[self.bias_name].dtype
self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype).to(device)
self.bias = self.bias.copy_(weight_dict[self.bias_name])
else:
self.bias = None
......@@ -228,14 +245,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
if self.config.get("weight_auto_quant", False):
self.weight = weight_dict[self.weight_name]
self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
else:
self.load_quantized(weight_dict)
if self.bias_name is not None:
self.bias = weight_dict[self.bias_name]
self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
device = weight_dict[self.bias_name].device
bias_shape = weight_dict[self.bias_name].shape
bias_dtype = weight_dict[self.bias_name].dtype
self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype).to(device)
self.bias = self.bias.copy_(weight_dict[self.bias_name])
else:
self.bias = None
......@@ -713,9 +731,12 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
assert not self.lazy_load
self.load_func(weight_dict)
self.workspace = weight_dict[f"{self.weight_name}_workspace"]
if self.bias_name is not None:
self.bias = weight_dict[self.bias_name]
self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
device = weight_dict[self.bias_name].device
bias_shape = weight_dict[self.bias_name].shape
bias_dtype = weight_dict[self.bias_name].dtype
self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype).to(device)
self.bias = self.bias.copy_(weight_dict[self.bias_name])
else:
self.bias = None
......
......@@ -98,11 +98,8 @@ class QwenImageTransformerModel:
return False
def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
with safe_open(file_path, framework="pt") as f:
return {
key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device)
for key in f.keys()
}
with safe_open(file_path, framework="pt", device=str(self.device)) as f:
return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())) for key in f.keys()}
def _load_ckpt(self, unified_dtype, sensitive_layer):
safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
......
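For the safetensors loaders here and in WanModel below, the per-tensor pin_memory().to(self.device) staging step is dropped in favor of letting safetensors place each tensor on the target device directly. A rough equivalent of the new loading path, with a hypothetical function name:

from safetensors import safe_open
import torch

def load_file_to_device(file_path, device, dtype):
    # safe_open materializes tensors on `device` itself, so no pinned CPU
    # staging copy is created for every tensor in the checkpoint.
    with safe_open(file_path, framework="pt", device=str(device)) as f:
        return {key: f.get_tensor(key).to(dtype) for key in f.keys()}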
......@@ -39,9 +39,12 @@ class WanAudioModel(WanModel):
adapter_offload = self.config.get("cpu_offload", False)
self.adapter_weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=adapter_offload, remove_key="audio")
if not adapter_offload and not dist.is_initialized():
for key, value in self.adapter_weights_dict.items():
self.adapter_weights_dict[key] = value.cuda()
if not dist.is_initialized():
for key in self.adapter_weights_dict:
# if adapter_offload:
# self.adapter_weights_dict[key] = self.adapter_weights_dict[key].pin_memory()
# else:
self.adapter_weights_dict[key] = self.adapter_weights_dict[key].pin_memory().to("cuda")
def _init_infer_class(self):
super()._init_infer_class()
......
......@@ -128,11 +128,8 @@ class WanModel(CompiledMethodsMixin):
return False
def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
with safe_open(file_path, framework="pt") as f:
return {
key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device)
for key in f.keys()
}
with safe_open(file_path, framework="pt", device=str(self.device)) as f:
return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())) for key in f.keys()}
def _load_ckpt(self, unified_dtype, sensitive_layer):
safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_original_ckpt", subdir="original")
......@@ -173,11 +170,11 @@ class WanModel(CompiledMethodsMixin):
torch.float,
]:
if unified_dtype or all(s not in k for s in sensitive_layer):
weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device)
else:
weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_SENSITIVE_DTYPE()).to(self.device)
weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device)
else:
weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
weight_dict[k] = f.get_tensor(k).to(self.device)
return weight_dict
......@@ -195,11 +192,11 @@ class WanModel(CompiledMethodsMixin):
torch.float,
]:
if unified_dtype or all(s not in k for s in sensitive_layer):
pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
pre_post_weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device)
else:
pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_SENSITIVE_DTYPE()).to(self.device)
pre_post_weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device)
else:
pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
pre_post_weight_dict[k] = f.get_tensor(k).to(self.device)
return pre_post_weight_dict
......
......@@ -53,20 +53,22 @@ class MultiDistillModelStruct(MultiModelStruct):
if self.scheduler.step_index < self.boundary_step_index:
logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
if self.cur_model_index == -1:
self.to_cuda(model_index=0)
elif self.cur_model_index == 1: # 1 -> 0
self.offload_cpu(model_index=1)
self.to_cuda(model_index=0)
if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
if self.cur_model_index == -1:
self.to_cuda(model_index=0)
elif self.cur_model_index == 1: # 1 -> 0
self.offload_cpu(model_index=1)
self.to_cuda(model_index=0)
self.cur_model_index = 0
else:
logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
if self.cur_model_index == -1:
self.to_cuda(model_index=1)
elif self.cur_model_index == 0: # 0 -> 1
self.offload_cpu(model_index=0)
self.to_cuda(model_index=1)
if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
if self.cur_model_index == -1:
self.to_cuda(model_index=1)
elif self.cur_model_index == 0: # 0 -> 1
self.offload_cpu(model_index=0)
self.to_cuda(model_index=1)
self.cur_model_index = 1
......
......@@ -59,6 +59,12 @@ class WanRunner(DefaultRunner):
def load_image_encoder(self):
image_encoder = None
if self.config.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
# offload config
clip_offload = self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False))
if clip_offload:
clip_device = torch.device("cpu")
else:
clip_device = torch.device("cuda")
# quant_config
clip_quantized = self.config.get("clip_quantized", False)
if clip_quantized:
......@@ -76,12 +82,12 @@ class WanRunner(DefaultRunner):
image_encoder = CLIPModel(
dtype=torch.float16,
device=self.init_device,
device=clip_device,
checkpoint_path=clip_original_ckpt,
clip_quantized=clip_quantized,
clip_quantized_ckpt=clip_quantized_ckpt,
quant_scheme=clip_quant_scheme,
cpu_offload=self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False)),
cpu_offload=clip_offload,
use_31_block=self.config.get("use_31_block", True),
)
......
......@@ -52,6 +52,10 @@ def set_config(args):
with open(os.path.join(config.model_path, "low_noise_model", "config.json"), "r") as f:
model_config = json.load(f)
config.update(model_config)
elif os.path.exists(os.path.join(config.model_path, "distill_models", "low_noise_model", "config.json")): # 需要一个更优雅的update方法
with open(os.path.join(config.model_path, "distill_models", "low_noise_model", "config.json"), "r") as f:
model_config = json.load(f)
config.update(model_config)
elif os.path.exists(os.path.join(config.model_path, "original", "config.json")):
with open(os.path.join(config.model_path, "original", "config.json"), "r") as f:
model_config = json.load(f)
......
......@@ -324,49 +324,50 @@ def find_gguf_model_path(config, ckpt_config_key=None, subdir=None):
raise FileNotFoundError(f"No GGUF model files (.gguf) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")
def load_safetensors(in_path: str):
def load_safetensors(in_path, remove_key):
if os.path.isdir(in_path):
return load_safetensors_from_dir(in_path)
return load_safetensors_from_dir(in_path, remove_key)
elif os.path.isfile(in_path):
return load_safetensors_from_path(in_path)
return load_safetensors_from_path(in_path, remove_key)
else:
raise ValueError(f"{in_path} does not exist")
def load_safetensors_from_path(in_path: str):
def load_safetensors_from_path(in_path, remove_key):
tensors = {}
with safetensors.safe_open(in_path, framework="pt", device="cpu") as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
if remove_key is None or remove_key not in key:
tensors[key] = f.get_tensor(key)
return tensors
def load_safetensors_from_dir(in_dir: str):
def load_safetensors_from_dir(in_dir, remove_key):
tensors = {}
safetensors = os.listdir(in_dir)
safetensors = [f for f in safetensors if f.endswith(".safetensors")]
for f in safetensors:
tensors.update(load_safetensors_from_path(os.path.join(in_dir, f)))
tensors.update(load_safetensors_from_path(os.path.join(in_dir, f), remove_key))
return tensors
def load_pt_safetensors(in_path: str):
def load_pt_safetensors(in_path, remove_key):
ext = os.path.splitext(in_path)[-1]
if ext in (".pt", ".pth", ".tar"):
state_dict = torch.load(in_path, map_location="cpu", weights_only=True)
for key in list(state_dict.keys()):
if remove_key and remove_key in key:
state_dict.pop(key)
else:
state_dict = load_safetensors(in_path)
state_dict = load_safetensors(in_path, remove_key)
return state_dict
def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
if not dist.is_initialized():
# Single GPU mode
cpu_weight_dict = load_pt_safetensors(checkpoint_path)
for key in list(cpu_weight_dict.keys()):
if remove_key and remove_key in key:
cpu_weight_dict.pop(key)
logger.info(f"Loading weights from {checkpoint_path}")
cpu_weight_dict = load_pt_safetensors(checkpoint_path, remove_key)
return cpu_weight_dict
# Multi-GPU mode
......
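With this change the remove_key filter is applied while each file is read, so excluded tensors are never materialized in the returned dict. A hypothetical call, with placeholder paths:

# Load everything (remove_key defaults to None, so nothing is filtered out).
dit_weights = load_weights("/path/to/dit.safetensors")
# Drop any tensor whose key contains "audio", as WanAudioModel does above.
adapter_weights = load_weights("/path/to/audio_adapter.pt", cpu_offload=True, remove_key="audio")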
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644