Unverified commit 29cf163b authored by Chi, committed by GitHub

Remove Redundant Variables from Encoder and Decoder (#5569)



* I added a new docstring to the class. This makes it easier for other developers to understand what the class does and where it is used.

* Update src/diffusers/models/unet_2d_blocks.py

This change was suggested by the maintainer.
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/models/unet_2d_blocks.py

Add suggested text
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update unet_2d_blocks.py

I changed the `Parameters` heading to `Args` in the docstring (see the illustrative sketch below the commit message).

* Update unet_2d_blocks.py

Set proper indentation in this file.

* Update unet_2d_blocks.py

Made a small change to the `act_fn` argument line.

* I ran the `black` command to reformat the code style.

* Update unet_2d_blocks.py

Added a docstring similar to the one in the original diffusion repository.

* I removed the redundant local variables (`sample = x` in the encoder, `sample = z` in the decoder) from both `forward` methods.

* Ran `black` again to reformat the file.

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 839c2a5e
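
The commits above mention switching the docstring's `Parameters` heading to the `Args:` style used across diffusers. A hypothetical sketch of that convention follows; the class stub and the argument entries are illustrative, not copied verbatim from `src/diffusers/models/vae.py`.

```python
# Hypothetical sketch of the `Args:` docstring convention described above;
# the argument entries are illustrative, not copied verbatim from the file.
import torch.nn as nn


class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input
    into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels for the last block.
    """
```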
@@ -130,9 +130,9 @@ class Encoder(nn.Module):
         self.gradient_checkpointing = False
 
-    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         r"""The forward method of the `Encoder` class."""
-        sample = x
         sample = self.conv_in(sample)
 
         if self.training and self.gradient_checkpointing:
@@ -273,9 +273,11 @@ class Decoder(nn.Module):
         self.gradient_checkpointing = False
 
-    def forward(self, z: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+    def forward(
+        self, sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None
+    ) -> torch.FloatTensor:
         r"""The forward method of the `Decoder` class."""
-        sample = z
         sample = self.conv_in(sample)
 
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
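
With the `x`/`z` parameters renamed to `sample` and the redundant aliases gone, a minimal smoke test of the new signatures might look like the sketch below. It assumes the `Encoder`/`Decoder` classes from `src/diffusers/models/vae.py`; the channel sizes and block choices are illustrative, not taken from the PR.

```python
# Minimal sketch, assuming diffusers' Encoder/Decoder from
# src/diffusers/models/vae.py; channel sizes below are illustrative.
import torch
from diffusers.models.vae import Decoder, Encoder

encoder = Encoder(
    in_channels=3,
    out_channels=4,
    down_block_types=("DownEncoderBlock2D",),
    block_out_channels=(32,),
    layers_per_block=1,
    double_z=True,  # conv_out emits 2 * out_channels (mean + logvar)
)
decoder = Decoder(
    in_channels=4,
    out_channels=3,
    up_block_types=("UpDecoderBlock2D",),
    block_out_channels=(32,),
    layers_per_block=1,
)

with torch.no_grad():
    image = torch.randn(1, 3, 32, 32)
    moments = encoder(image)  # forward(sample=...) after the rename
    mean = moments[:, :4]     # keep the mean half of (mean, logvar)
    recon = decoder(mean)     # forward(sample=...) after the rename

assert recon.shape == image.shape
```

The point of the rename is that the public parameter name now matches the variable actually used in the body, so callers passing keyword arguments use `sample=` for both classes and no dead assignment survives.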