Commit 52f3ffc0 authored by gushiqiao, committed by GitHub
Browse files

[Fix] Fix oom bug (#311)

parent 3c778aee
......@@ -39,12 +39,9 @@ class WanAudioModel(WanModel):
adapter_offload = self.config.get("cpu_offload", False)
self.adapter_weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=adapter_offload, remove_key="audio")
if not dist.is_initialized():
if not dist.is_initialized() and not adapter_offload:
for key in self.adapter_weights_dict:
# if adapter_offload:
# self.adapter_weights_dict[key] = self.adapter_weights_dict[key].pin_memory()
# else:
self.adapter_weights_dict[key] = self.adapter_weights_dict[key].pin_memory().to("cuda")
self.adapter_weights_dict[key] = self.adapter_weights_dict[key].cuda()
def _init_infer_class(self):
super()._init_infer_class()
......
import gc
import json
import os
......@@ -253,6 +254,10 @@ class WanModel(CompiledMethodsMixin):
self.pre_weight.load(self.original_weight_dict)
self.transformer_weights.load(self.original_weight_dict)
del self.original_weight_dict
torch.cuda.empty_cache()
gc.collect()
def _load_weights_distribute(self, weight_dict, is_weight_loader):
global_src_rank = 0
target_device = "cpu" if self.cpu_offload else "cuda"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment