Unverified Commit f21da849 authored by Yang Yong (雍洋), committed by GitHub
Browse files
parent 3efc43f5
......@@ -11,6 +11,7 @@ class WanPreInfer:
assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
self.config = config
d = config["dim"] // config["num_heads"]
self.run_device = self.config.get("run_device", "cuda")
self.clean_cuda_cache = config.get("clean_cuda_cache", False)
self.task = config["task"]
self.device = torch.device(self.config.get("run_device", "cuda"))
......@@ -21,7 +22,7 @@ class WanPreInfer:
rope_params(1024, 2 * (d // 6)),
],
dim=1,
).to(self.device)
).to(torch.device(self.run_device))
self.freq_dim = config["freq_dim"]
self.dim = config["dim"]
self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
......
......@@ -48,6 +48,7 @@ class WanModel(CompiledMethodsMixin):
super().__init__()
self.model_path = model_path
self.config = config
self.run_device = self.config.get("run_device", "cuda")
self.cpu_offload = self.config.get("cpu_offload", False)
self.offload_granularity = self.config.get("offload_granularity", "block")
self.model_type = model_type
......
This diff is collapsed.
This diff is collapsed.
......@@ -450,7 +450,7 @@ class WanAudioRunner(WanRunner): # type:ignore
ref_img = img_path
else:
ref_img = load_image(img_path)
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).to(self.init_device)
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).to(self.run_device)
ref_img, h, w = resize_image(
ref_img,
......@@ -538,7 +538,7 @@ class WanAudioRunner(WanRunner): # type:ignore
def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]:
"""Prepare previous latents for conditioning"""
device = self.init_device
device = self.run_device
dtype = GET_DTYPE()
tgt_h, tgt_w = self.input_info.target_shape[0], self.input_info.target_shape[1]
......@@ -835,7 +835,7 @@ class WanAudioRunner(WanRunner): # type:ignore
def load_audio_encoder(self):
    """Instantiate the audio encoder model for this runner.

    Configuration keys read from ``self.config``:
      - ``audio_encoder_path``: checkpoint directory; defaults to
        ``<model_path>/TencentGameMate-chinese-hubert-large``.
      - ``audio_encoder_cpu_offload``: per-component offload flag; falls back
        to the global ``cpu_offload`` flag when absent.
      - ``audio_sr``: audio sample rate passed to the encoder.
      - ``run_device``: target device string (default ``"cuda"``).

    Returns:
        The constructed ``SekoAudioEncoderModel`` instance.
    """
    audio_encoder_path = self.config.get(
        "audio_encoder_path",
        os.path.join(self.config["model_path"], "TencentGameMate-chinese-hubert-large"),
    )
    # Component-specific offload setting wins; otherwise inherit the global one.
    audio_encoder_offload = self.config.get(
        "audio_encoder_cpu_offload", self.config.get("cpu_offload", False)
    )
    # NOTE(review): the scraped diff carried both the old (device=) and new
    # (run_device=) constructor calls; only the committed run_device= form is kept.
    model = SekoAudioEncoderModel(
        audio_encoder_path,
        self.config["audio_sr"],
        audio_encoder_offload,
        run_device=self.config.get("run_device", "cuda"),
    )
    return model
def load_audio_adapter(self):
......@@ -843,7 +843,8 @@ class WanAudioRunner(WanRunner): # type:ignore
if audio_adapter_offload:
device = torch.device("cpu")
else:
device = torch.device(self.config.get("run_device", "cuda"))
device = torch.device(self.run_device)
audio_adapter = AudioAdapter(
attention_head_dim=self.config["dim"] // self.config["num_heads"],
num_attention_heads=self.config["num_heads"],
......@@ -856,7 +857,7 @@ class WanAudioRunner(WanRunner): # type:ignore
quantized=self.config.get("adapter_quantized", False),
quant_scheme=self.config.get("adapter_quant_scheme", None),
cpu_offload=audio_adapter_offload,
device=self.config.get("run_device", "cuda"),
run_device=self.run_device,
)
audio_adapter.to(device)
......@@ -896,6 +897,7 @@ class Wan22AudioRunner(WanAudioRunner):
vae_config = {
"vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
"device": vae_device,
"run_device": self.run_device,
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
}
......@@ -912,6 +914,7 @@ class Wan22AudioRunner(WanAudioRunner):
vae_config = {
"vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
"device": vae_device,
"run_device": self.run_device,
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
}
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment