"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "6a72e8fc7beb77883f7d56e3c24443f957f3ea46"
Commit bd21f14f authored by gushiqiao, committed by GitHub

[Fix] Fix oom bug (#307)
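Summary, as read from the diffs below: the offload path previously kept a full-size pinned host buffer (`pinned_weight` / `pinned_weight_scale` / `pinned_bias`) alongside every weight, and additionally pinned every tensor while reading checkpoints, which could roughly double host-memory use and trigger the OOM. The fix allocates each weight directly in pinned memory once, loads safetensors straight onto the target device, adds `t5_cpu_offload` / `vae_cpu_offload` switches to the configs, and only shuttles the distilled high/low-noise models through CPU when model-granularity offload is enabled. A minimal sketch of the before/after host-memory pattern (shapes and names are illustrative, not the repo's API; pinned allocation needs a CUDA-enabled PyTorch build):

```python
import torch

src = torch.randn(4096, 4096)  # stand-in for a checkpoint tensor

# Before: the weight plus a persistent pinned staging buffer, i.e. ~2x host
# memory held for the lifetime of the model.
weight_old = src.t()
pinned_weight = torch.empty(weight_old.shape, pin_memory=True, dtype=weight_old.dtype)

# After: a single pinned allocation that *is* the weight.
weight_new = torch.empty(src.t().shape, pin_memory=True, dtype=src.dtype)
weight_new.copy_(src.t())
```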

parent e120838b
@@ -13,6 +13,8 @@
     "enable_cfg": false,
     "cpu_offload": true,
     "offload_granularity": "model",
+    "t5_cpu_offload": false,
+    "vae_cpu_offload": false,
     "use_image_encoder": false,
     "boundary_step_index": 2,
     "denoising_step_list": [1000, 750, 500, 250],
......
@@ -13,5 +13,7 @@
     "enable_cfg": true,
     "cpu_offload": true,
     "offload_granularity": "model",
+    "t5_cpu_offload": false,
+    "vae_cpu_offload": false,
     "boundary": 0.875
 }
@@ -13,6 +13,8 @@
     "enable_cfg": false,
     "cpu_offload": true,
     "offload_granularity": "model",
+    "t5_cpu_offload": false,
+    "vae_cpu_offload": false,
     "boundary_step_index": 2,
     "denoising_step_list": [1000, 750, 500, 250],
     "lora_configs": [
......
@@ -15,6 +15,8 @@
     "enable_cfg": true,
     "cpu_offload": false,
     "offload_granularity": "model",
+    "t5_cpu_offload": false,
+    "vae_cpu_offload": false,
     "fps": 24,
     "use_image_encoder": false
 }
@@ -17,5 +17,7 @@
     "use_image_encoder": false,
     "cpu_offload": true,
     "offload_granularity": "model",
+    "t5_cpu_offload": false,
+    "vae_cpu_offload": false,
     "vae_offload_cache": true
 }
@@ -15,5 +15,7 @@
     "enable_cfg": true,
     "cpu_offload": false,
     "offload_granularity": "model",
+    "t5_cpu_offload": false,
+    "vae_cpu_offload": false,
     "fps": 24
 }
@@ -16,5 +16,7 @@
     "fps": 24,
     "cpu_offload": true,
     "offload_granularity": "model",
+    "t5_cpu_offload": false,
+    "vae_cpu_offload": false,
     "vae_offload_cache": true
 }
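For orientation, a minimal offload section combining the keys touched above might look as follows (a sketch assembled from the diffs, not a shipped config file; written as Python for consistency with the rest of the examples):

```python
import json

offload_config = {
    "cpu_offload": True,             # offload the DiT between steps
    "offload_granularity": "model",  # move whole models, not single blocks
    "t5_cpu_offload": False,         # new: T5 text encoder stays on GPU
    "vae_cpu_offload": False,        # new: VAE stays on GPU
    "vae_offload_cache": True,
}
print(json.dumps(offload_config, indent=4))
```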
@@ -70,18 +70,11 @@ class MMWeightTemplate(metaclass=ABCMeta):
             self.bias = self.bias.cuda(non_blocking=non_blocking)

     def to_cpu(self, non_blocking=False):
-        if hasattr(self, "pinned_weight"):
-            self.weight = self.pinned_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
-            if hasattr(self, "weight_scale_name"):
-                self.weight_scale = self.pinned_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
-            if self.bias is not None:
-                self.bias = self.pinned_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
-        else:
-            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
-            if hasattr(self, "weight_scale"):
-                self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
-            if hasattr(self, "bias") and self.bias is not None:
-                self.bias = self.bias.to("cpu", non_blocking=non_blocking)
+        self.weight = self.weight.to("cpu", non_blocking=non_blocking)
+        if hasattr(self, "weight_scale"):
+            self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
+        if hasattr(self, "bias") and self.bias is not None:
+            self.bias = self.bias.to("cpu", non_blocking=non_blocking)

 @MM_WEIGHT_REGISTER("Default")
@@ -90,10 +83,20 @@ class MMWeight(MMWeightTemplate):
         super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

     def load(self, weight_dict):
-        self.weight = weight_dict[self.weight_name].t()
-        self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
-        self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
-        self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) if self.bias is not None else None
+        device = weight_dict[self.weight_name].device
+        weight_shape = weight_dict[self.weight_name].t().shape
+        weight_dtype = weight_dict[self.weight_name].dtype
+        self.weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype).to(device)
+        self.weight = self.weight.copy_(weight_dict[self.weight_name].t())
+        if self.bias_name is not None:
+            bias_shape = weight_dict[self.bias_name].shape
+            bias_dtype = weight_dict[self.bias_name].dtype
+            self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype).to(device)
+            self.bias = self.bias.copy_(weight_dict[self.bias_name])
+        else:
+            self.bias = None

     def _calculate_size(self):
         if self.bias is not None:
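The same allocate-pinned-then-copy pattern recurs in every load function below. As a standalone sketch (the tensor is a stand-in; pinned allocation needs a CUDA-enabled PyTorch build):

```python
import torch

src = torch.randn(1024, 1024)  # stand-in for weight_dict[name]

# Allocate the destination once in page-locked (pinned) host memory and copy
# the checkpoint tensor into it. The trailing .to(device) is a no-op when the
# source already lives on the CPU; if the checkpoint were loaded directly to
# CUDA it would move (and unpin) the tensor instead.
weight = torch.empty(src.t().shape, pin_memory=True, dtype=src.dtype).to(src.device)
weight.copy_(src.t())
assert weight.is_pinned()  # holds in the CPU-checkpoint case
```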
@@ -166,7 +169,6 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         self.load_func(weight_dict)
         if self.weight_need_transpose:
             self.weight = self.weight.t()
-            self.pinned_weight = self.pinned_weight.t()

     def clear(self):
         attrs = ["weight", "weight_scale", "bias", "pinned_weight", "pinned_weight_scale", "pinned_bias"]
@@ -182,11 +184,24 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size()

     def load_quantized(self, weight_dict):
-        self.weight = weight_dict[self.weight_name]
-        self.weight_scale = weight_dict[self.weight_scale_name].float()
-        self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
-        self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
+        device = weight_dict[self.weight_name].device
+        weight_shape = weight_dict[self.weight_name].shape
+        weight_dtype = weight_dict[self.weight_name].dtype
+        self.weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype).to(device)
+        self.weight = self.weight.copy_(weight_dict[self.weight_name])
+
+        weight_scale_shape = weight_dict[self.weight_scale_name].shape
+        weight_scale_dtype = torch.float
+        self.weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype).to(device)
+        self.weight_scale = self.weight_scale.copy_(weight_dict[self.weight_scale_name])
+
+        if self.bias_name is not None:
+            bias_shape = weight_dict[self.bias_name].shape
+            bias_dtype = weight_dict[self.bias_name].dtype
+            self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype).to(device)
+            self.bias = self.bias.copy_(weight_dict[self.bias_name])
+        else:
+            self.bias = None

     def load_fp8_perchannel_sym(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
@@ -195,14 +210,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
             self.weight = self.weight.to(torch.float8_e4m3fn)
             self.weight_scale = self.weight_scale.to(torch.float32)
-            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
-            self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
         else:
             self.load_quantized(weight_dict)
         if self.bias_name is not None:
-            self.bias = weight_dict[self.bias_name]
-            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
+            device = weight_dict[self.bias_name].device
+            bias_shape = weight_dict[self.bias_name].shape
+            bias_dtype = weight_dict[self.bias_name].dtype
+            self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype).to(device)
+            self.bias = self.bias.copy_(weight_dict[self.bias_name])
         else:
             self.bias = None
@@ -213,14 +229,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
             self.weight = self.weight.to(torch.int8)
             self.weight_scale = self.weight_scale.to(torch.float32)
-            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
-            self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
         else:
             self.load_quantized(weight_dict)
         if self.bias_name is not None:
-            self.bias = weight_dict[self.bias_name]
-            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
+            device = weight_dict[self.bias_name].device
+            bias_shape = weight_dict[self.bias_name].shape
+            bias_dtype = weight_dict[self.bias_name].dtype
+            self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype).to(device)
+            self.bias = self.bias.copy_(weight_dict[self.bias_name])
         else:
             self.bias = None
@@ -228,14 +245,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         if self.config.get("weight_auto_quant", False):
             self.weight = weight_dict[self.weight_name]
             self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
-            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)
-            self.pinned_weight_scale = torch.empty(self.weight_scale.shape, pin_memory=True, dtype=self.weight_scale.dtype)
         else:
             self.load_quantized(weight_dict)
         if self.bias_name is not None:
-            self.bias = weight_dict[self.bias_name]
-            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
+            device = weight_dict[self.bias_name].device
+            bias_shape = weight_dict[self.bias_name].shape
+            bias_dtype = weight_dict[self.bias_name].dtype
+            self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype).to(device)
+            self.bias = self.bias.copy_(weight_dict[self.bias_name])
         else:
             self.bias = None
@@ -713,9 +731,12 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
         assert not self.lazy_load
         self.load_func(weight_dict)
         self.workspace = weight_dict[f"{self.weight_name}_workspace"]
         if self.bias_name is not None:
-            self.bias = weight_dict[self.bias_name]
-            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
+            bias_shape = weight_dict[self.bias_name].shape
+            bias_dtype = weight_dict[self.bias_name].dtype
+            self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype).to(device)
+            self.bias = self.bias.copy_(weight_dict[self.bias_name])
         else:
             self.bias = None
......
@@ -98,11 +98,8 @@ class QwenImageTransformerModel:
         return False

     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
-        with safe_open(file_path, framework="pt") as f:
-            return {
-                key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device)
-                for key in f.keys()
-            }
+        with safe_open(file_path, framework="pt", device=str(self.device)) as f:
+            return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())) for key in f.keys()}

     def _load_ckpt(self, unified_dtype, sensitive_layer):
         safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
......
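Passing `device=` to safetensors' `safe_open` makes `get_tensor` return tensors already on the target device, so the loader no longer routes every tensor through an extra pinned CPU copy. A minimal usage sketch (the file path is hypothetical; `"cuda:0"` needs a GPU):

```python
from safetensors import safe_open

# device="cuda:0" loads each tensor straight to the GPU; "cpu" keeps the
# tensors on the host.
with safe_open("model-00001.safetensors", framework="pt", device="cuda:0") as f:
    tensors = {key: f.get_tensor(key) for key in f.keys()}
```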
@@ -39,9 +39,12 @@ class WanAudioModel(WanModel):
         adapter_offload = self.config.get("cpu_offload", False)
         self.adapter_weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=adapter_offload, remove_key="audio")
-        if not adapter_offload and not dist.is_initialized():
-            for key, value in self.adapter_weights_dict.items():
-                self.adapter_weights_dict[key] = value.cuda()
+        if not dist.is_initialized():
+            for key in self.adapter_weights_dict:
+                # if adapter_offload:
+                #     self.adapter_weights_dict[key] = self.adapter_weights_dict[key].pin_memory()
+                # else:
+                self.adapter_weights_dict[key] = self.adapter_weights_dict[key].pin_memory().to("cuda")

     def _init_infer_class(self):
         super()._init_infer_class()
......
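The adapter weights now go through a transient pinned copy on their way to the GPU: `pin_memory()` returns a page-locked copy, the host-to-device transfer from it is fast, and the copy is freed once the dict entry rebinds to the CUDA tensor. A standalone sketch (the dict is a stand-in for the adapter weights):

```python
import torch

adapter_weights = {"proj.weight": torch.randn(512, 512)}

if torch.cuda.is_available():
    for key in adapter_weights:
        # transient pinned copy -> fast H2D transfer -> pinned copy freed
        adapter_weights[key] = adapter_weights[key].pin_memory().to("cuda")
```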
@@ -128,11 +128,8 @@ class WanModel(CompiledMethodsMixin):
         return False

     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
-        with safe_open(file_path, framework="pt") as f:
-            return {
-                key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device)
-                for key in f.keys()
-            }
+        with safe_open(file_path, framework="pt", device=str(self.device)) as f:
+            return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())) for key in f.keys()}

     def _load_ckpt(self, unified_dtype, sensitive_layer):
         safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_original_ckpt", subdir="original")
@@ -173,11 +170,11 @@ class WanModel(CompiledMethodsMixin):
                     torch.float,
                 ]:
                     if unified_dtype or all(s not in k for s in sensitive_layer):
-                        weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
+                        weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device)
                     else:
-                        weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_SENSITIVE_DTYPE()).to(self.device)
+                        weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device)
                 else:
-                    weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
+                    weight_dict[k] = f.get_tensor(k).to(self.device)
         return weight_dict
@@ -195,11 +192,11 @@ class WanModel(CompiledMethodsMixin):
                     torch.float,
                 ]:
                     if unified_dtype or all(s not in k for s in sensitive_layer):
-                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
+                        pre_post_weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device)
                     else:
-                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_SENSITIVE_DTYPE()).to(self.device)
+                        pre_post_weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device)
                 else:
-                    pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
+                    pre_post_weight_dict[k] = f.get_tensor(k).to(self.device)
         return pre_post_weight_dict
......
@@ -53,20 +53,22 @@ class MultiDistillModelStruct(MultiModelStruct):
         if self.scheduler.step_index < self.boundary_step_index:
             logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
             self.scheduler.sample_guide_scale = self.config.sample_guide_scale[0]
-            if self.cur_model_index == -1:
-                self.to_cuda(model_index=0)
-            elif self.cur_model_index == 1:  # 1 -> 0
-                self.offload_cpu(model_index=1)
-                self.to_cuda(model_index=0)
+            if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
+                if self.cur_model_index == -1:
+                    self.to_cuda(model_index=0)
+                elif self.cur_model_index == 1:  # 1 -> 0
+                    self.offload_cpu(model_index=1)
+                    self.to_cuda(model_index=0)
             self.cur_model_index = 0
         else:
             logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
             self.scheduler.sample_guide_scale = self.config.sample_guide_scale[1]
-            if self.cur_model_index == -1:
-                self.to_cuda(model_index=1)
-            elif self.cur_model_index == 0:  # 0 -> 1
-                self.offload_cpu(model_index=0)
-                self.to_cuda(model_index=1)
+            if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
+                if self.cur_model_index == -1:
+                    self.to_cuda(model_index=1)
+                elif self.cur_model_index == 0:  # 0 -> 1
+                    self.offload_cpu(model_index=0)
+                    self.to_cuda(model_index=1)
             self.cur_model_index = 1
......
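Both branches now gate the CPU-to-GPU swap on `cpu_offload` being enabled with model-level granularity; without that, both distilled models are assumed resident and `offload_cpu`/`to_cuda` must not run. A hypothetical consolidation of the duplicated branches (not the repo's code, just to make the gating easier to see):

```python
def _switch_model(self, want: int) -> None:
    # want: 0 = high-noise model, 1 = low-noise model
    if self.config.get("cpu_offload", False) and self.config.get("offload_granularity", "block") == "model":
        if self.cur_model_index == -1:
            self.to_cuda(model_index=want)       # first use: just load it
        elif self.cur_model_index == 1 - want:   # swap the resident model
            self.offload_cpu(model_index=1 - want)
            self.to_cuda(model_index=want)
    self.cur_model_index = want
```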
@@ -59,6 +59,12 @@ class WanRunner(DefaultRunner):
     def load_image_encoder(self):
         image_encoder = None
         if self.config.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
+            # offload config
+            clip_offload = self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False))
+            if clip_offload:
+                clip_device = torch.device("cpu")
+            else:
+                clip_device = torch.device("cuda")
             # quant_config
             clip_quantized = self.config.get("clip_quantized", False)
             if clip_quantized:
@@ -76,12 +82,12 @@ class WanRunner(DefaultRunner):
             image_encoder = CLIPModel(
                 dtype=torch.float16,
-                device=self.init_device,
+                device=clip_device,
                 checkpoint_path=clip_original_ckpt,
                 clip_quantized=clip_quantized,
                 clip_quantized_ckpt=clip_quantized_ckpt,
                 quant_scheme=clip_quant_scheme,
-                cpu_offload=self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False)),
+                cpu_offload=clip_offload,
                 use_31_block=self.config.get("use_31_block", True),
             )
......
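The CLIP device is now derived once from the per-module offload flag, which falls back to the global `cpu_offload` flag. Isolated, the lookup chain is just (config shown as a plain dict for illustration):

```python
import torch

config = {"cpu_offload": True}  # no clip_cpu_offload key set

# Per-module flag wins; otherwise inherit the global flag; default False.
clip_offload = config.get("clip_cpu_offload", config.get("cpu_offload", False))
clip_device = torch.device("cpu" if clip_offload else "cuda")
print(clip_offload, clip_device)  # True cpu
```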
@@ -52,6 +52,10 @@ def set_config(args):
         with open(os.path.join(config.model_path, "low_noise_model", "config.json"), "r") as f:
             model_config = json.load(f)
         config.update(model_config)
+    elif os.path.exists(os.path.join(config.model_path, "distill_models", "low_noise_model", "config.json")):  # TODO: needs a more elegant update mechanism
+        with open(os.path.join(config.model_path, "distill_models", "low_noise_model", "config.json"), "r") as f:
+            model_config = json.load(f)
+        config.update(model_config)
    elif os.path.exists(os.path.join(config.model_path, "original", "config.json")):
         with open(os.path.join(config.model_path, "original", "config.json"), "r") as f:
             model_config = json.load(f)
......
@@ -324,49 +324,50 @@ def find_gguf_model_path(config, ckpt_config_key=None, subdir=None):
     raise FileNotFoundError(f"No GGUF model files (.gguf) found.\nPlease download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file.")


-def load_safetensors(in_path: str):
+def load_safetensors(in_path, remove_key):
     if os.path.isdir(in_path):
-        return load_safetensors_from_dir(in_path)
+        return load_safetensors_from_dir(in_path, remove_key)
     elif os.path.isfile(in_path):
-        return load_safetensors_from_path(in_path)
+        return load_safetensors_from_path(in_path, remove_key)
     else:
         raise ValueError(f"{in_path} does not exist")


-def load_safetensors_from_path(in_path: str):
+def load_safetensors_from_path(in_path, remove_key):
     tensors = {}
     with safetensors.safe_open(in_path, framework="pt", device="cpu") as f:
         for key in f.keys():
-            tensors[key] = f.get_tensor(key)
+            if remove_key not in key:
+                tensors[key] = f.get_tensor(key)
     return tensors


-def load_safetensors_from_dir(in_dir: str):
+def load_safetensors_from_dir(in_dir, remove_key):
     tensors = {}
     safetensors = os.listdir(in_dir)
     safetensors = [f for f in safetensors if f.endswith(".safetensors")]
     for f in safetensors:
-        tensors.update(load_safetensors_from_path(os.path.join(in_dir, f)))
+        tensors.update(load_safetensors_from_path(os.path.join(in_dir, f), remove_key))
     return tensors


-def load_pt_safetensors(in_path: str):
+def load_pt_safetensors(in_path, remove_key):
     ext = os.path.splitext(in_path)[-1]
     if ext in (".pt", ".pth", ".tar"):
         state_dict = torch.load(in_path, map_location="cpu", weights_only=True)
+        for key in list(state_dict.keys()):
+            if remove_key and remove_key in key:
+                state_dict.pop(key)
     else:
-        state_dict = load_safetensors(in_path)
+        state_dict = load_safetensors(in_path, remove_key)
     return state_dict


 def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
     if not dist.is_initialized():
         # Single GPU mode
-        cpu_weight_dict = load_pt_safetensors(checkpoint_path)
-        for key in list(cpu_weight_dict.keys()):
-            if remove_key and remove_key in key:
-                cpu_weight_dict.pop(key)
         logger.info(f"Loading weights from {checkpoint_path}")
+        cpu_weight_dict = load_pt_safetensors(checkpoint_path, remove_key)
         return cpu_weight_dict

     # Multi-GPU mode
......
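With this refactor, key filtering happens inside the loaders, so tensors matching `remove_key` are skipped at read time instead of being loaded and popped afterwards. A self-contained sketch of the filter on an in-memory dict (stand-in data, not the repo's loaders):

```python
import torch

def filter_keys(state_dict, remove_key=None):
    # Mirrors the new load_pt_safetensors behavior: drop every entry whose
    # key contains remove_key; with remove_key=None, keep everything.
    return {k: v for k, v in state_dict.items() if not (remove_key and remove_key in k)}

sd = {"audio_proj.weight": torch.zeros(2), "blocks.0.attn.qkv.weight": torch.zeros(2)}
print(list(filter_keys(sd, remove_key="audio").keys()))  # ['blocks.0.attn.qkv.weight']
```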
File mode changed from 100755 to 100644 (5 files)