"vscode:/vscode.git/clone" did not exist on "c15689027bc9984e7e4d096e99aedbac445d5fa3"
Unverified commit f21da849, authored by Yang Yong (雍洋) and committed by GitHub
parent 3efc43f5
@@ -11,6 +11,7 @@ class WanPreInfer:
         assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
         self.config = config
         d = config["dim"] // config["num_heads"]
+        self.run_device = self.config.get("run_device", "cuda")
         self.clean_cuda_cache = config.get("clean_cuda_cache", False)
         self.task = config["task"]
         self.device = torch.device(self.config.get("run_device", "cuda"))
@@ -21,7 +22,7 @@ class WanPreInfer:
                 rope_params(1024, 2 * (d // 6)),
             ],
             dim=1,
-        ).to(self.device)
+        ).to(torch.device(self.run_device))
         self.freq_dim = config["freq_dim"]
         self.dim = config["dim"]
         self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
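Both hunks above follow the same pattern: read run_device from the config once in __init__, cache it as an attribute, and use that attribute for every later tensor placement instead of re-querying the config. A minimal runnable sketch of the pattern, with an illustrative class name and helper (not the real WanPreInfer):

import torch

class DevicePlacementSketch:
    """Illustrative stand-in for WanPreInfer's device handling."""

    def __init__(self, config: dict):
        # Resolve the compute device once; "cuda" is the default used
        # throughout this commit.
        self.run_device = config.get("run_device", "cuda")

    def build_freqs(self, n: int) -> torch.Tensor:
        # Tensors built later are moved with the cached attribute,
        # mirroring `.to(torch.device(self.run_device))` in the diff.
        freqs = torch.arange(n, dtype=torch.float32)
        return freqs.to(torch.device(self.run_device))

sketch = DevicePlacementSketch({"run_device": "cpu"})
print(sketch.build_freqs(8).device)  # cpu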
@@ -48,6 +48,7 @@ class WanModel(CompiledMethodsMixin):
         super().__init__()
         self.model_path = model_path
         self.config = config
+        self.run_device = self.config.get("run_device", "cuda")
         self.cpu_offload = self.config.get("cpu_offload", False)
         self.offload_granularity = self.config.get("offload_granularity", "block")
         self.model_type = model_type
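WanModel caches the same attribute next to its offload settings. Note that run_device and cpu_offload answer different questions: the former is where compute runs, the latter whether weights rest on CPU between uses. A hedged sketch of how the two typically combine (the helper name is illustrative):

import torch

def resolve_devices(config: dict):
    # Illustrative: "run_device" is where compute happens; "cpu_offload"
    # decides where weights rest between uses.
    run_device = torch.device(config.get("run_device", "cuda"))
    storage = torch.device("cpu") if config.get("cpu_offload", False) else run_device
    return storage, run_device

print(resolve_devices({"run_device": "cpu", "cpu_offload": True}))
# (device(type='cpu'), device(type='cpu'))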
@@ -450,7 +450,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             ref_img = img_path
         else:
             ref_img = load_image(img_path)
-        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).to(self.init_device)
+        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).to(self.run_device)
         ref_img, h, w = resize_image(
             ref_img,
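Only the destination device changes in this hunk; the preprocessing itself is unchanged: TF.to_tensor produces values in [0, 1], and the in-place sub_(0.5).div_(0.5) rescales them to [-1, 1] before adding a batch dimension. A self-contained example of that normalization:

import torchvision.transforms.functional as TF
from PIL import Image

# to_tensor yields values in [0, 1]; sub_(0.5).div_(0.5) rescales in place
# to [-1, 1]; unsqueeze(0) adds the batch dimension.
img = Image.new("RGB", (64, 64), color=(255, 0, 0))
ref = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(0)
print(ref.shape, ref.min().item(), ref.max().item())  # (1, 3, 64, 64) -1.0 1.0
ref = ref.to("cpu")  # the hunk only changes this destination device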
@@ -538,7 +538,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]:
         """Prepare previous latents for conditioning"""
-        device = self.init_device
+        device = self.run_device
         dtype = GET_DTYPE()
         tgt_h, tgt_w = self.input_info.target_shape[0], self.input_info.target_shape[1]
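prepare_prev_latents resolves device and dtype once at the top and allocates everything downstream against them; the hunk only swaps the device source from init_device to run_device. A sketch of that allocate-once pattern (the channel count and the 8x spatial stride are stand-ins, not values from the commit):

import torch

def prepare_latents_sketch(tgt_h: int, tgt_w: int,
                           device: str = "cpu",
                           dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Resolve placement once at the top, then allocate every conditioning
    # tensor against it, so nothing lands on a different device mid-function.
    dev = torch.device(device)
    return torch.zeros(1, 4, tgt_h // 8, tgt_w // 8, device=dev, dtype=dtype)

print(prepare_latents_sketch(512, 512).shape)  # torch.Size([1, 4, 64, 64])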
@@ -835,7 +835,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def load_audio_encoder(self):
         audio_encoder_path = self.config.get("audio_encoder_path", os.path.join(self.config["model_path"], "TencentGameMate-chinese-hubert-large"))
         audio_encoder_offload = self.config.get("audio_encoder_cpu_offload", self.config.get("cpu_offload", False))
-        model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload, device=self.config.get("run_device", "cuda"))
+        model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload, run_device=self.config.get("run_device", "cuda"))
         return model

     def load_audio_adapter(self):
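load_audio_encoder also shows the per-module override pattern used throughout this commit: a module-specific config key falls back to the global cpu_offload flag, which falls back to a default. A minimal sketch:

def resolve_offload(config: dict, module_key: str) -> bool:
    # A module-specific key wins; otherwise inherit the global flag.
    return config.get(module_key, config.get("cpu_offload", False))

cfg = {"cpu_offload": True}
print(resolve_offload(cfg, "audio_encoder_cpu_offload"))  # True (inherited)
cfg["audio_encoder_cpu_offload"] = False
print(resolve_offload(cfg, "audio_encoder_cpu_offload"))  # False (override)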
@@ -843,7 +843,8 @@ class WanAudioRunner(WanRunner):  # type:ignore
         if audio_adapter_offload:
             device = torch.device("cpu")
         else:
-            device = torch.device(self.config.get("run_device", "cuda"))
+            device = torch.device(self.run_device)
         audio_adapter = AudioAdapter(
             attention_head_dim=self.config["dim"] // self.config["num_heads"],
             num_attention_heads=self.config["num_heads"],
@@ -856,7 +857,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             quantized=self.config.get("adapter_quantized", False),
             quant_scheme=self.config.get("adapter_quant_scheme", None),
             cpu_offload=audio_adapter_offload,
-            device=self.config.get("run_device", "cuda"),
+            run_device=self.run_device,
         )
         audio_adapter.to(device)
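These two hunks rename the adapter's device keyword to run_device and derive it from the cached attribute. The surrounding logic keeps weights on CPU when offloading and on run_device otherwise, while run_device always tells the module where compute will happen once it is invoked. A sketch of that placement decision (the helper name is illustrative):

import torch

def adapter_rest_device(offload: bool, run_device: str = "cuda") -> torch.device:
    # Weights rest on CPU when offloading; otherwise they live on run_device.
    # Either way, run_device is passed to the module separately so it knows
    # where compute will run.
    return torch.device("cpu") if offload else torch.device(run_device)

print(adapter_rest_device(offload=True))                     # cpu
print(adapter_rest_device(offload=False, run_device="cpu"))  # cpu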
@@ -896,6 +897,7 @@ class Wan22AudioRunner(WanAudioRunner):
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
             "device": vae_device,
+            "run_device": self.run_device,
             "cpu_offload": vae_offload,
             "offload_cache": self.config.get("vae_offload_cache", False),
         }
@@ -912,6 +914,7 @@ class Wan22AudioRunner(WanAudioRunner):
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
             "device": vae_device,
+            "run_device": self.run_device,
             "cpu_offload": vae_offload,
             "offload_cache": self.config.get("vae_offload_cache", False),
         }
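Both VAE construction paths in Wan22AudioRunner now thread run_device through vae_config alongside the existing device entry, keeping storage placement (device, CPU when offloaded) separate from compute placement (run_device). A sketch of building such a config (the helper and the fixed offload_cache default are assumptions, not the commit's code):

import torch

def build_vae_config(vae_path: str, cpu_offload: bool,
                     run_device: str = "cuda") -> dict:
    # "device" is where the VAE weights rest (CPU when offloaded);
    # "run_device" is where encode/decode compute runs after this commit.
    return {
        "vae_path": vae_path,
        "device": torch.device("cpu") if cpu_offload else torch.device(run_device),
        "run_device": run_device,
        "cpu_offload": cpu_offload,
        "offload_cache": False,
    }

print(build_vae_config("Wan2.2_VAE.pth", cpu_offload=True, run_device="cpu"))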