Commit 99158e75 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

[Fix] Fix adapter weights distribute load bug (#320)

parent 0fc763de
...@@ -40,9 +40,10 @@ class WanAudioModel(WanModel): ...@@ -40,9 +40,10 @@ class WanAudioModel(WanModel):
adapter_offload = self.config.get("cpu_offload", False) adapter_offload = self.config.get("cpu_offload", False)
load_from_rank0 = self.config.get("load_from_rank0", False) load_from_rank0 = self.config.get("load_from_rank0", False)
self.adapter_weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0) self.adapter_weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0)
if not dist.is_initialized() and not adapter_offload: if not adapter_offload:
for key in self.adapter_weights_dict: if not dist.is_initialized() or not load_from_rank0:
self.adapter_weights_dict[key] = self.adapter_weights_dict[key].cuda() for key in self.adapter_weights_dict:
self.adapter_weights_dict[key] = self.adapter_weights_dict[key].cuda()
def _init_infer_class(self): def _init_infer_class(self):
super()._init_infer_class() super()._init_infer_class()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment