Commit bc72d297 authored by patil-suraj's avatar patil-suraj
Browse files

make Diffwave subclass of ModelMixin

parent 86da45bc
...@@ -19,6 +19,8 @@ import torch.nn as nn ...@@ -19,6 +19,8 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import tqdm import tqdm
from ..modeling_utils import ModelMixin
from ..configuration_utils import ConfigMixin
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
...@@ -209,14 +211,35 @@ class ResidualGroup(nn.Module): ...@@ -209,14 +211,35 @@ class ResidualGroup(nn.Module):
return skip * math.sqrt(1.0 / self.num_res_layers) return skip * math.sqrt(1.0 / self.num_res_layers)
class DiffWave(nn.Module): class DiffWave(ModelMixin, ConfigMixin):
def __init__(self, in_channels, res_channels, skip_channels, out_channels, def __init__(
num_res_layers, dilation_cycle, self,
diffusion_step_embed_dim_in, in_channels=1,
diffusion_step_embed_dim_mid, res_channels=128,
diffusion_step_embed_dim_out): skip_channels=128,
out_channels=1,
num_res_layers=30,
dilation_cycle=10,
diffusion_step_embed_dim_in=128,
diffusion_step_embed_dim_mid=512,
diffusion_step_embed_dim_out=512,
):
super().__init__() super().__init__()
# register all init arguments with self.register
self.register(
in_channels=in_channels,
res_channels=res_channels,
skip_channels=skip_channels,
out_channels=out_channels,
num_res_layers=num_res_layers,
dilation_cycle=dilation_cycle,
diffusion_step_embed_dim_in=diffusion_step_embed_dim_in,
diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid,
diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
)
# Initial conv1x1 with relu # Initial conv1x1 with relu
self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU(inplace=False)) self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU(inplace=False))
# All residual layers # All residual layers
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment