Unverified Commit 9f3c0fdc authored by Pavle Padjin, committed by GitHub

Avoiding graph break by changing the way we infer dtype in vae.decoder (#12512)

* Changing the way we infer dtype to avoid forcing evaluation of lazy tensors

* Changing the way we infer dtype to ensure type consistency

* More robust inferring of dtype

* Removing the upscale dtype entirely
parent 84e16575
@@ -286,11 +286,9 @@ class Decoder(nn.Module):
         sample = self.conv_in(sample)

-        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             # middle
             sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
-            sample = sample.to(upscale_dtype)

             # up
             for up_block in self.up_blocks:
@@ -298,7 +296,6 @@ class Decoder(nn.Module):
         else:
             # middle
             sample = self.mid_block(sample, latent_embeds)
-            sample = sample.to(upscale_dtype)

             # up
             for up_block in self.up_blocks:
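To illustrate the change, here is a minimal sketch (a hypothetical `TinyDecoder`, not the actual diffusers class) of the pattern being removed: reading a dtype out of `up_blocks` parameters forces materialization of lazy/fake tensors, which can break the graph under `torch.compile`. Since `mid_block`'s output already shares a dtype with the `up_blocks` parameters in normal use, the explicit cast can simply be dropped.

```python
# Minimal sketch of the patched pattern; TinyDecoder is an illustrative
# stand-in for diffusers' Decoder, not the real implementation.
import torch
import torch.nn as nn


class TinyDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.mid_block = nn.Linear(4, 4)
        self.up_blocks = nn.ModuleList([nn.Linear(4, 4)])

    def forward(self, sample):
        # Removed pattern: reading a parameter just for its dtype forces
        # evaluation of lazy tensors and can cause a graph break:
        #   upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        #   sample = sample.to(upscale_dtype)
        sample = self.mid_block(sample)
        # The up blocks receive the sample in its existing dtype, which is
        # already consistent with their parameters.
        for up_block in self.up_blocks:
            sample = up_block(sample)
        return sample


decoder = TinyDecoder()
out = decoder(torch.randn(2, 4))
```

The sample's dtype flows through unchanged, so no extra cast (and no parameter read) is needed on the hot path.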