Unverified Commit 9ae90593 authored by Chi's avatar Chi Committed by GitHub
Browse files

Replacing the nn.Mish activation function with a get_activation function. (#5651)



* I added a new docstring to the class. This makes it easier for other developers to understand what the class does and where it is used.

* Update src/diffusers/models/unet_2d_blocks.py

This change was suggested by the maintainer.
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/models/unet_2d_blocks.py

Add suggested text
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Update unet_2d_blocks.py

I changed the "Parameters" heading to "Args".

* Update unet_2d_blocks.py

Set proper indentation in this file.

* Update unet_2d_blocks.py

Made a small change to the act_fun argument line.

* I ran the black command to reformat the code style

* Update unet_2d_blocks.py

Added a docstring similar to the one in the original diffusion repository.

* I removed the dummy variable defined in both the encoder and decoder.

* Now I ran the black package to reformat my file

* Remove the redundant line from the adapter.py file.

* Used the black package to reformat my file

* Replacing the nn.Mish activation function with a get_activation function allows developers to more easily choose the right activation function for their task. Additionally, removing redundant variables can improve code readability and maintainability.

* I tried to fix this: Fast tests for PRs / Fast PyTorch Models & Schedulers CPU tests (pull_request)

* Update src/diffusers/models/resnet.py
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 7942bb8d
...@@ -778,16 +778,22 @@ class Conv1dBlock(nn.Module): ...@@ -778,16 +778,22 @@ class Conv1dBlock(nn.Module):
out_channels (`int`): Number of output channels. out_channels (`int`): Number of output channels.
kernel_size (`int` or `tuple`): Size of the convolving kernel. kernel_size (`int` or `tuple`): Size of the convolving kernel.
n_groups (`int`, default `8`): Number of groups to separate the channels into. n_groups (`int`, default `8`): Number of groups to separate the channels into.
activation (`str`, defaults `mish`): Name of the activation function.
""" """
def __init__( def __init__(
self, inp_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], n_groups: int = 8 self,
inp_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
n_groups: int = 8,
activation: str = "mish",
): ):
super().__init__() super().__init__()
self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2) self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
self.group_norm = nn.GroupNorm(n_groups, out_channels) self.group_norm = nn.GroupNorm(n_groups, out_channels)
self.mish = nn.Mish() self.mish = get_activation(activation)
def forward(self, inputs: torch.Tensor) -> torch.Tensor: def forward(self, inputs: torch.Tensor) -> torch.Tensor:
intermediate_repr = self.conv1d(inputs) intermediate_repr = self.conv1d(inputs)
...@@ -808,16 +814,22 @@ class ResidualTemporalBlock1D(nn.Module): ...@@ -808,16 +814,22 @@ class ResidualTemporalBlock1D(nn.Module):
out_channels (`int`): Number of output channels. out_channels (`int`): Number of output channels.
embed_dim (`int`): Embedding dimension. embed_dim (`int`): Embedding dimension.
kernel_size (`int` or `tuple`): Size of the convolving kernel. kernel_size (`int` or `tuple`): Size of the convolving kernel.
activation (`str`, defaults `mish`): It is possible to choose the right activation function.
""" """
def __init__( def __init__(
self, inp_channels: int, out_channels: int, embed_dim: int, kernel_size: Union[int, Tuple[int, int]] = 5 self,
inp_channels: int,
out_channels: int,
embed_dim: int,
kernel_size: Union[int, Tuple[int, int]] = 5,
activation: str = "mish",
): ):
super().__init__() super().__init__()
self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
self.time_emb_act = nn.Mish() self.time_emb_act = get_activation(activation)
self.time_emb = nn.Linear(embed_dim, out_channels) self.time_emb = nn.Linear(embed_dim, out_channels)
self.residual_conv = ( self.residual_conv = (
......
...@@ -162,8 +162,8 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -162,8 +162,8 @@ class VQModel(ModelMixin, ConfigMixin):
If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple` If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple`
is returned. is returned.
""" """
x = sample
h = self.encode(x).latents h = self.encode(sample).latents
dec = self.decode(h).sample dec = self.decode(h).sample
if not return_dict: if not return_dict:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment