Unverified commit 9ae90593, authored by Chi, committed by GitHub

Replacing the nn.Mish activation function with a get_activation function. (#5651)



* I added a new docstring to the class. This makes it easier for other developers to understand what the class does and where it is used.

* Update src/diffusers/models/unet_2d_blocks.py

This change was suggested by a maintainer.
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/models/unet_2d_blocks.py

Add suggested text
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update unet_2d_blocks.py

I changed the `Parameter` heading to `Args` in the docstring.

* Update unet_2d_blocks.py

Set proper indentation in this file.

* Update unet_2d_blocks.py

Made a small change to the `act_fun` argument line.

* I ran the `black` command to reformat the code.

* Update unet_2d_blocks.py

Added a docstring similar to the one in the original diffusion repository.

* I removed the dummy variable defined in both the encoder and decoder.

* Ran the `black` package again to reformat my file.

* Remove the redundant line from the adapter.py file.

* Used the `black` package to reformat my file.

* Replacing the hard-coded `nn.Mish` activation with a `get_activation` lookup lets developers choose the activation function that suits their task (a sketch of such a lookup follows these notes). Removing redundant variables also improves code readability and maintainability.

* I tried to fix the failing check: Fast tests for PRs / Fast PyTorch Models & Schedulers CPU tests (pull_request).

* Update src/diffusers/models/resnet.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
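The heart of the change is a small name-to-module lookup. Below is a minimal sketch of what a `get_activation`-style helper typically does; the supported names and error handling are illustrative and may not match the exact helper that diffusers ships.

```python
import torch.nn as nn

# Illustrative sketch of a name-to-module activation lookup; the real
# get_activation helper in diffusers may support other names and errors.
_ACTIVATIONS = {
    "swish": nn.SiLU,
    "silu": nn.SiLU,
    "mish": nn.Mish,
    "gelu": nn.GELU,
    "relu": nn.ReLU,
}


def get_activation(act_fn: str) -> nn.Module:
    """Return a fresh activation module registered under ``act_fn``."""
    try:
        return _ACTIVATIONS[act_fn.lower()]()
    except KeyError:
        raise ValueError(f"Unsupported activation function: {act_fn}")
```

With such a lookup in place, blocks like `Conv1dBlock` and `ResidualTemporalBlock1D` can expose an `activation: str = "mish"` argument, staying backward compatible by default while allowing other activations to be selected by name.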
parent 7942bb8d
@@ -778,16 +778,22 @@ class Conv1dBlock(nn.Module):
         out_channels (`int`): Number of output channels.
         kernel_size (`int` or `tuple`): Size of the convolving kernel.
         n_groups (`int`, default `8`): Number of groups to separate the channels into.
+        activation (`str`, defaults `mish`): Name of the activation function.
     """
 
     def __init__(
-        self, inp_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], n_groups: int = 8
+        self,
+        inp_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int]],
+        n_groups: int = 8,
+        activation: str = "mish",
     ):
         super().__init__()
 
         self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
         self.group_norm = nn.GroupNorm(n_groups, out_channels)
-        self.mish = nn.Mish()
+        self.mish = get_activation(activation)
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         intermediate_repr = self.conv1d(inputs)
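A usage sketch for the updated constructor, assuming `Conv1dBlock` is importable from `diffusers.models.resnet` as in this PR's source tree; the default keeps the previous `nn.Mish` behaviour, while any name the lookup understands can be passed instead.

```python
import torch
from diffusers.models.resnet import Conv1dBlock  # import path assumed from this PR

# Default keeps the old behaviour: Conv1d -> GroupNorm -> Mish.
mish_block = Conv1dBlock(inp_channels=32, out_channels=64, kernel_size=5)

# Any activation name understood by get_activation can be passed instead.
silu_block = Conv1dBlock(inp_channels=32, out_channels=64, kernel_size=5, activation="silu")

x = torch.randn(2, 32, 128)  # (batch, channels, length)
print(mish_block(x).shape, silu_block(x).shape)  # both (2, 64, 128); padding preserves length
```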
@@ -808,16 +814,22 @@ class ResidualTemporalBlock1D(nn.Module):
         out_channels (`int`): Number of output channels.
         embed_dim (`int`): Embedding dimension.
         kernel_size (`int` or `tuple`): Size of the convolving kernel.
+        activation (`str`, defaults `mish`): Name of the activation function to use.
     """
 
     def __init__(
-        self, inp_channels: int, out_channels: int, embed_dim: int, kernel_size: Union[int, Tuple[int, int]] = 5
+        self,
+        inp_channels: int,
+        out_channels: int,
+        embed_dim: int,
+        kernel_size: Union[int, Tuple[int, int]] = 5,
+        activation: str = "mish",
     ):
         super().__init__()
         self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
         self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
 
-        self.time_emb_act = nn.Mish()
+        self.time_emb_act = get_activation(activation)
         self.time_emb = nn.Linear(embed_dim, out_channels)
 
         self.residual_conv = (
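`ResidualTemporalBlock1D` gains the same argument for its time-embedding activation. A short sketch, assuming the block's `forward(inputs, t)` takes a sample of shape `(batch, channels, length)` and a time embedding of shape `(batch, embed_dim)` as in the surrounding diffusers code:

```python
import torch
from diffusers.models.resnet import ResidualTemporalBlock1D  # import path assumed from this PR

# The time-embedding activation is now configurable by name as well.
block = ResidualTemporalBlock1D(inp_channels=32, out_channels=64, embed_dim=128, activation="gelu")

inputs = torch.randn(2, 32, 16)   # (batch, inp_channels, length)
t_emb = torch.randn(2, 128)       # (batch, embed_dim)
print(block(inputs, t_emb).shape)  # expected (2, 64, 16)
```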
@@ -162,8 +162,8 @@ class VQModel(ModelMixin, ConfigMixin):
                 If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple`
                 is returned.
         """
-        x = sample
-        h = self.encode(x).latents
+
+        h = self.encode(sample).latents
         dec = self.decode(h).sample
 
         if not return_dict: