Unverified Commit 5f277e80 authored by Gu Shiqiao's avatar Gu Shiqiao Committed by GitHub
Browse files

Set run_device in constructor before loading adapter (#517)

Refactor initialization to set run_device before loading adapter
checkpoint.
parent dea872a2
...@@ -22,8 +22,9 @@ class WanAudioModel(WanModel): ...@@ -22,8 +22,9 @@ class WanAudioModel(WanModel):
def __init__(self, model_path, config, device): def __init__(self, model_path, config, device):
self.config = config self.config = config
self.run_device = self.config.get("run_device", "cuda")
self._load_adapter_ckpt()
super().__init__(model_path, config, device) super().__init__(model_path, config, device)
self._load_adapter_ckpt() # depend on run_device
def _load_adapter_ckpt(self): def _load_adapter_ckpt(self):
if self.config.get("adapter_model_path", None) is None: if self.config.get("adapter_model_path", None) is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment