Commit 4659c635 authored by helloyongyang's avatar helloyongyang
Browse files

fix ci

parent a4818f0f
...@@ -18,7 +18,7 @@ class WanScheduler4ChangingResolution(WanScheduler): ...@@ -18,7 +18,7 @@ class WanScheduler4ChangingResolution(WanScheduler):
device=self.device, device=self.device,
generator=self.generator, generator=self.generator,
) )
self.noise_original_resolution = torch.randn( self.noise_original_resolution = torch.randn(
target_shape[0], target_shape[0],
target_shape[1], target_shape[1],
...@@ -45,21 +45,17 @@ class WanScheduler4ChangingResolution(WanScheduler): ...@@ -45,21 +45,17 @@ class WanScheduler4ChangingResolution(WanScheduler):
# 2. upsample clean noise to target shape # 2. upsample clean noise to target shape
denoised_sample_5d = denoised_sample.unsqueeze(0) # (C,T,H,W) -> (1,C,T,H,W) denoised_sample_5d = denoised_sample.unsqueeze(0) # (C,T,H,W) -> (1,C,T,H,W)
clean_noise = torch.nn.functional.interpolate( clean_noise = torch.nn.functional.interpolate(denoised_sample_5d, size=(self.config.target_shape[1], self.config.target_shape[2], self.config.target_shape[3]), mode="trilinear")
denoised_sample_5d, clean_noise = clean_noise.squeeze(0) # (1,C,T,H,W) -> (C,T,H,W)
size=(self.config.target_shape[1], self.config.target_shape[2], self.config.target_shape[3]),
mode='trilinear'
)
clean_noise = clean_noise.squeeze(0) # (1,C,T,H,W) -> (C,T,H,W)
# 3. add noise to clean noise # 3. add noise to clean noise
noisy_sample = self.add_noise(clean_noise, self.noise_original_resolution, self.timesteps[self.step_index + 1]) noisy_sample = self.add_noise(clean_noise, self.noise_original_resolution, self.timesteps[self.step_index + 1])
# 4. update latents # 4. update latents
self.latents = noisy_sample self.latents = noisy_sample
# self.disable_corrector = [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37] # maybe not needed # self.disable_corrector = [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37] # maybe not needed
# 5. update timesteps using shift + 2 更激进的去噪 # 5. update timesteps using shift + 2 更激进的去噪
self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift + 2) self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift + 2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment