Unverified Commit 4929cd02 authored by gushiqiao, committed by GitHub

[Update] Update lora adapter (#336)

parent e51914a5
@@ -43,7 +43,7 @@ class WanLoraWrapper:
         lora_weights = self._load_lora_file(self.lora_metadata[lora_name]["path"])
         weight_dict = self.model.original_weight_dict
         self._apply_lora_weights(weight_dict, lora_weights, alpha)
-        self.model._init_weights(weight_dict)
+        self.model._apply_weights(weight_dict)
         logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
         del lora_weights
@@ -119,7 +119,7 @@ class WanLoraWrapper:
         logger.info(f"LoRA removed, restored {restored_count} weights")
-        self.model._init_weights(self.model.original_weight_dict)
+        self.model._apply_weights(self.model.original_weight_dict)
         torch.cuda.empty_cache()
         gc.collect()
...
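Note: these hunks only rename the call on the model (_init_weights becomes _apply_weights); the delta merge itself happens in _apply_lora_weights, whose body is not part of this diff. For orientation, below is a minimal sketch of the standard LoRA merge such a helper typically performs, W' = W + alpha * (B @ A); the function name, shapes, and scaling convention are illustrative assumptions, not code from this repository.

import torch

def merge_lora_pair(weight, lora_a, lora_b, alpha):
    # Standard LoRA merge: W' = W + alpha * (B @ A).
    # Assumed shapes: weight (out, in), lora_a (rank, in), lora_b (out, rank).
    return weight + alpha * (lora_b @ lora_a)

# Toy check: rank-4 update merged into a 16x16 weight.
w = torch.randn(16, 16)
a = torch.randn(4, 16)
b = torch.randn(16, 4)
w_merged = merge_lora_pair(w, a, b, alpha=0.8)
assert w_merged.shape == w.shape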
@@ -56,6 +56,11 @@ class WanModel(CompiledMethodsMixin):
         else:
             self.seq_p_group = None
 
+        if self.config.get("lora_configs") and self.config.lora_configs:
+            self.init_empty_model = True
+        else:
+            self.init_empty_model = False
+
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
@@ -254,7 +259,14 @@ class WanModel(CompiledMethodsMixin):
         # Initialize weight containers
         self.pre_weight = self.pre_weight_class(self.config)
         self.transformer_weights = self.transformer_weight_class(self.config)
+        if not self.init_empty_model:
+            self._apply_weights()
 
+    def _apply_weights(self, weight_dict=None):
+        if weight_dict is not None:
+            self.original_weight_dict = weight_dict
+            del weight_dict
+            gc.collect()
         # Load weights into containers
         self.pre_weight.load(self.original_weight_dict)
         self.transformer_weights.load(self.original_weight_dict)
...
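Taken together, the two WanModel hunks defer weight loading when LoRA configs are present: __init__ sets init_empty_model and skips loading, and _apply_weights later loads either the original weight dict or a LoRA-merged replacement supplied by WanLoraWrapper. Below is a minimal self-contained sketch of that pattern, assuming plain dicts in place of the real config object and weight containers; TinyModel and loaded are placeholders, while init_empty_model and _apply_weights mirror the diff.

import gc

class TinyModel:
    def __init__(self, config, weight_dict):
        self.config = config
        self.original_weight_dict = weight_dict
        # Defer applying weights when LoRA configs are present, so a LoRA
        # wrapper can merge deltas first and call _apply_weights itself.
        self.init_empty_model = bool(config.get("lora_configs"))
        if not self.init_empty_model:
            self._apply_weights()

    def _apply_weights(self, weight_dict=None):
        if weight_dict is not None:
            self.original_weight_dict = weight_dict
            del weight_dict  # drop the local alias, as the diff does
            gc.collect()
        # Stand-in for pre_weight.load(...) and transformer_weights.load(...).
        self.loaded = dict(self.original_weight_dict)

plain = TinyModel({}, {"w": 1.0})                 # no LoRA: applied at init
assert plain.loaded == {"w": 1.0}

lazy = TinyModel({"lora_configs": ["cfg"]}, {"w": 1.0})
assert not hasattr(lazy, "loaded")                # started empty
lazy._apply_weights({"w": 1.5})                   # e.g. a LoRA-merged dict
assert lazy.loaded == {"w": 1.5}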