Unverified Commit b32321e4 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

[Fix] fix wanti2v-t2v vae encode bug (#408)

parent e24de2ec
...@@ -532,10 +532,10 @@ class Wan22DenseRunner(WanRunner): ...@@ -532,10 +532,10 @@ class Wan22DenseRunner(WanRunner):
# to tensor # to tensor
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda().unsqueeze(1) img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda().unsqueeze(1)
vae_encoder_out = self.get_vae_encoder_output(img) vae_encoder_out = self.get_vae_encoder_output(img)
self.config.lat_w, self.config.lat_h = ow // self.config.vae_stride[2], oh // self.config.vae_stride[1] latent_w, latent_h = ow // self.config["vae_stride"][2], oh // self.config["vae_stride"][1]
latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)
return vae_encoder_out return vae_encoder_out, latent_shape
def get_vae_encoder_output(self, img): def get_vae_encoder_output(self, img):
z = self.vae_encoder.encode(img.to(GET_DTYPE())) z = self.vae_encoder.encode(img.unsqueeze(0).to(GET_DTYPE()))
return z return z
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment