Commit d9b5cedd authored by helloyongyang
Browse files

Simplify runner

parent 7d109a7c
...@@ -140,7 +140,7 @@ class DefaultRunner(BaseRunner): ...@@ -140,7 +140,7 @@ class DefaultRunner(BaseRunner):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"] prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
img = Image.open(self.config["image_path"]).convert("RGB") img = Image.open(self.config["image_path"]).convert("RGB")
clip_encoder_out = self.run_image_encoder(img) clip_encoder_out = self.run_image_encoder(img)
vae_encode_out, kwargs = self.run_vae_encoder(img) vae_encode_out = self.run_vae_encoder(img)
text_encoder_output = self.run_text_encoder(prompt, img) text_encoder_output = self.run_text_encoder(prompt, img)
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
...@@ -158,7 +158,7 @@ class DefaultRunner(BaseRunner): ...@@ -158,7 +158,7 @@ class DefaultRunner(BaseRunner):
} }
@ProfilingContext("Run DiT") @ProfilingContext("Run DiT")
def _run_dit_local(self, kwargs): def _run_dit_local(self):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.model = self.load_transformer() self.model = self.load_transformer()
self.init_scheduler() self.init_scheduler()
...@@ -205,9 +205,9 @@ class DefaultRunner(BaseRunner): ...@@ -205,9 +205,9 @@ class DefaultRunner(BaseRunner):
self.inputs = self.run_input_encoder() self.inputs = self.run_input_encoder()
kwargs = self.set_target_shape() self.set_target_shape()
latents, generator = self.run_dit(kwargs) latents, generator = self.run_dit()
images = self.run_vae_decoder(latents, generator) images = self.run_vae_decoder(latents, generator)
......
...@@ -202,7 +202,6 @@ class WanRunner(DefaultRunner): ...@@ -202,7 +202,6 @@ class WanRunner(DefaultRunner):
return clip_encoder_out return clip_encoder_out
def run_vae_encoder(self, img): def run_vae_encoder(self, img):
kwargs = {}
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda() img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
h, w = img.shape[1:] h, w = img.shape[1:]
aspect_ratio = h / w aspect_ratio = h / w
...@@ -212,8 +211,7 @@ class WanRunner(DefaultRunner): ...@@ -212,8 +211,7 @@ class WanRunner(DefaultRunner):
h = lat_h * self.config.vae_stride[1] h = lat_h * self.config.vae_stride[1]
w = lat_w * self.config.vae_stride[2] w = lat_w * self.config.vae_stride[2]
self.config.lat_h, kwargs["lat_h"] = lat_h, lat_h self.config.lat_h, self.config.lat_w = lat_h, lat_w
self.config.lat_w, kwargs["lat_w"] = lat_w, lat_w
msk = torch.ones( msk = torch.ones(
1, 1,
...@@ -245,7 +243,7 @@ class WanRunner(DefaultRunner): ...@@ -245,7 +243,7 @@ class WanRunner(DefaultRunner):
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16) vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
return vae_encode_out, kwargs return vae_encode_out
def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img): def get_encoder_output_i2v(self, clip_encoder_out, vae_encode_out, text_encoder_output, img):
image_encoder_output = { image_encoder_output = {
...@@ -258,7 +256,6 @@ class WanRunner(DefaultRunner): ...@@ -258,7 +256,6 @@ class WanRunner(DefaultRunner):
} }
def set_target_shape(self): def set_target_shape(self):
ret = {}
num_channels_latents = self.config.get("num_channels_latents", 16) num_channels_latents = self.config.get("num_channels_latents", 16)
if self.config.task == "i2v": if self.config.task == "i2v":
self.config.target_shape = ( self.config.target_shape = (
...@@ -267,8 +264,6 @@ class WanRunner(DefaultRunner): ...@@ -267,8 +264,6 @@ class WanRunner(DefaultRunner):
self.config.lat_h, self.config.lat_h,
self.config.lat_w, self.config.lat_w,
) )
ret["lat_h"] = self.config.lat_h
ret["lat_w"] = self.config.lat_w
elif self.config.task == "t2v": elif self.config.task == "t2v":
self.config.target_shape = ( self.config.target_shape = (
num_channels_latents, num_channels_latents,
...@@ -276,8 +271,6 @@ class WanRunner(DefaultRunner): ...@@ -276,8 +271,6 @@ class WanRunner(DefaultRunner):
int(self.config.target_height) // self.config.vae_stride[1], int(self.config.target_height) // self.config.vae_stride[1],
int(self.config.target_width) // self.config.vae_stride[2], int(self.config.target_width) // self.config.vae_stride[2],
) )
ret["target_shape"] = self.config.target_shape
return ret
def save_video_func(self, images): def save_video_func(self, images):
cache_video( cache_video(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment