"vscode:/vscode.git/clone" did not exist on "c41d6565295f039a97877cd63055760cd9e749a9"
Commit d9b5cedd authored by helloyongyang's avatar helloyongyang
Browse files

Simplify runner

parent 7d109a7c
......@@ -140,7 +140,7 @@ class DefaultRunner(BaseRunner):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
img = Image.open(self.config["image_path"]).convert("RGB")
clip_encoder_out = self.run_image_encoder(img)
vae_encode_out, kwargs = self.run_vae_encoder(img)
vae_encode_out = self.run_vae_encoder(img)
text_encoder_output = self.run_text_encoder(prompt, img)
torch.cuda.empty_cache()
gc.collect()
......@@ -158,7 +158,7 @@ class DefaultRunner(BaseRunner):
}
@ProfilingContext("Run DiT")
def _run_dit_local(self, kwargs):
def _run_dit_local(self):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.model = self.load_transformer()
self.init_scheduler()
......@@ -205,9 +205,9 @@ class DefaultRunner(BaseRunner):
self.inputs = self.run_input_encoder()
kwargs = self.set_target_shape()
self.set_target_shape()
latents, generator = self.run_dit(kwargs)
latents, generator = self.run_dit()
images = self.run_vae_decoder(latents, generator)
......
......@@ -202,7 +202,6 @@ class WanRunner(DefaultRunner):
return clip_encoder_out
def run_vae_encoder(self, img):
kwargs = {}
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
h, w = img.shape[1:]
aspect_ratio = h / w
......@@ -212,8 +211,7 @@ class WanRunner(DefaultRunner):
h = lat_h * self.config.vae_stride[1]
w = lat_w * self.config.vae_stride[2]
self.config.lat_h, kwargs["lat_h"] = lat_h, lat_h
self.config.lat_w, kwargs["lat_w"] = lat_w, lat_w
self.config.lat_h, self.config.lat_w = lat_h, lat_w
msk = torch.ones(
1,
......@@ -245,7 +243,7 @@ class WanRunner(DefaultRunner):
torch.cuda.empty_cache()
gc.collect()
vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
return vae_encode_out, kwargs
return vae_encode_out
def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
image_encoder_output = {
......@@ -258,7 +256,6 @@ class WanRunner(DefaultRunner):
}
def set_target_shape(self):
ret = {}
num_channels_latents = self.config.get("num_channels_latents", 16)
if self.config.task == "i2v":
self.config.target_shape = (
......@@ -267,8 +264,6 @@ class WanRunner(DefaultRunner):
self.config.lat_h,
self.config.lat_w,
)
ret["lat_h"] = self.config.lat_h
ret["lat_w"] = self.config.lat_w
elif self.config.task == "t2v":
self.config.target_shape = (
num_channels_latents,
......@@ -276,8 +271,6 @@ class WanRunner(DefaultRunner):
int(self.config.target_height) // self.config.vae_stride[1],
int(self.config.target_width) // self.config.vae_stride[2],
)
ret["target_shape"] = self.config.target_shape
return ret
def save_video_func(self, images):
cache_video(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment