"vscode:/vscode.git/clone" did not exist on "9000d93789e5ee2a2ba86c21e4f0ddc16f2a9343"
Commit 834b09c4 authored by helloyongyang
Browse files

[Feature] Support progressive resolution

parent 3e67df1c
......@@ -12,6 +12,6 @@
"enable_cfg": true,
"cpu_offload": false,
"changing_resolution": true,
"resolution_rate": 0.75,
"changing_resolution_steps": 20
"resolution_rate": [1.0, 0.75],
"changing_resolution_steps": [5, 25]
}
......@@ -13,6 +13,6 @@
"enable_cfg": true,
"cpu_offload": false,
"changing_resolution": true,
"resolution_rate": 0.75,
"changing_resolution_steps": 25
"resolution_rate": [1.0, 0.75],
"changing_resolution_steps": [10, 35]
}
......@@ -44,8 +44,8 @@ class WanPreInfer:
if self.task == "i2v":
clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
if self.config.get("changing_resolution", False) and self.scheduler.step_index > self.config.changing_resolution_steps - 1:
image_encoder = inputs["image_encoder_output"]["vae_encode_out_original_resolution"]
if self.config.get("changing_resolution", False):
image_encoder = inputs["image_encoder_output"]["vae_encode_out"][self.scheduler.changing_resolution_index]
else:
image_encoder = inputs["image_encoder_output"]["vae_encode_out"]
......
......@@ -211,12 +211,12 @@ class WanRunner(DefaultRunner):
if self.config.get("changing_resolution", False):
self.config.lat_h, self.config.lat_w = lat_h, lat_w
vae_encode_out_original_resolution = self.get_vae_encoder_output(img, lat_h, lat_w)
# get vae encode out at low resolution
lat_h, lat_w = int(self.config.lat_h * self.config.resolution_rate) // 2 * 2, int(self.config.lat_w * self.config.resolution_rate) // 2 * 2
vae_encode_out = self.get_vae_encoder_output(img, lat_h, lat_w)
return vae_encode_out, vae_encode_out_original_resolution # low resolution, original resolution
vae_encode_out_list = []
for i in range(len(self.config["resolution_rate"])):
lat_h, lat_w = int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2, int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2
vae_encode_out_list.append(self.get_vae_encoder_output(img, lat_h, lat_w))
vae_encode_out_list.append(self.get_vae_encoder_output(img, self.config.lat_h, self.config.lat_w))
return vae_encode_out_list
else:
self.config.lat_h, self.config.lat_w = lat_h, lat_w
vae_encode_out = self.get_vae_encoder_output(img, lat_h, lat_w)
......@@ -259,18 +259,10 @@ class WanRunner(DefaultRunner):
return vae_encode_out
def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
if self.config.get("changing_resolution", False):
image_encoder_output = {
"clip_encoder_out": clip_encoder_out,
"vae_encode_out": vae_encode_out[0],
"vae_encode_out_original_resolution": vae_encode_out[1],
}
else:
image_encoder_output = {
"clip_encoder_out": clip_encoder_out,
"vae_encode_out": vae_encode_out,
}
image_encoder_output = {
"clip_encoder_out": clip_encoder_out,
"vae_encode_out": vae_encode_out,
}
return {
"text_encoder_output": text_encoder_output,
"image_encoder_output": image_encoder_output,
......
......@@ -5,33 +5,48 @@ from lightx2v.models.schedulers.wan.scheduler import WanScheduler
class WanScheduler4ChangingResolution(WanScheduler):
def __init__(self, config):
super().__init__(config)
self.resolution_rate = config.get("resolution_rate", 0.75)
self.changing_resolution_steps = config.get("changing_resolution_steps", config.infer_steps // 2)
if "resolution_rate" not in config:
config["resolution_rate"] = [0.75]
if "changing_resolution_steps" not in config:
config["changing_resolution_steps"] = [config.infer_steps // 2]
assert len(config["resolution_rate"]) == len(config["changing_resolution_steps"])
def prepare_latents(self, target_shape, dtype=torch.float32):
self.latents = torch.randn(
target_shape[0],
target_shape[1],
int(target_shape[2] * self.resolution_rate) // 2 * 2,
int(target_shape[3] * self.resolution_rate) // 2 * 2,
dtype=dtype,
device=self.device,
generator=self.generator,
)
self.latents_list = []
for i in range(len(self.config["resolution_rate"])):
self.latents_list.append(
torch.randn(
target_shape[0],
target_shape[1],
int(target_shape[2] * self.config["resolution_rate"][i]) // 2 * 2,
int(target_shape[3] * self.config["resolution_rate"][i]) // 2 * 2,
dtype=dtype,
device=self.device,
generator=self.generator,
)
)
self.noise_original_resolution = torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=dtype,
device=self.device,
generator=self.generator,
# add original resolution latents
self.latents_list.append(
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=dtype,
device=self.device,
generator=self.generator,
)
)
# set initial latents
self.latents = self.latents_list[0]
self.changing_resolution_index = 0
def step_post(self):
if self.step_index == self.changing_resolution_steps - 1:
if self.step_index + 1 in self.config["changing_resolution_steps"]:
self.step_post_upsample()
self.changing_resolution_index += 1
else:
super().step_post()
......@@ -45,19 +60,21 @@ class WanScheduler4ChangingResolution(WanScheduler):
# 2. upsample clean noise to target shape
denoised_sample_5d = denoised_sample.unsqueeze(0) # (C,T,H,W) -> (1,C,T,H,W)
clean_noise = torch.nn.functional.interpolate(denoised_sample_5d, size=(self.config.target_shape[1], self.config.target_shape[2], self.config.target_shape[3]), mode="trilinear")
shape_to_upsampled = self.latents_list[self.changing_resolution_index + 1].shape[1:]
clean_noise = torch.nn.functional.interpolate(denoised_sample_5d, size=shape_to_upsampled, mode="trilinear")
clean_noise = clean_noise.squeeze(0) # (1,C,T,H,W) -> (C,T,H,W)
# 3. add noise to clean noise
noisy_sample = self.add_noise(clean_noise, self.noise_original_resolution, self.timesteps[self.step_index + 1])
noisy_sample = self.add_noise(clean_noise, self.latents_list[self.changing_resolution_index + 1], self.timesteps[self.step_index + 1])
# 4. update latents
self.latents = noisy_sample
# self.disable_corrector = [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37] # maybe not needed
# 5. update timesteps using shift + 2 (more aggressive denoising)
self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift + 2)
# 5. update timesteps using shift + self.changing_resolution_index + 1 (more aggressive denoising)
self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift + self.changing_resolution_index + 1)
def add_noise(self, original_samples, noise, timesteps):
sigma = self.sigmas[self.step_index]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment