Commit bc72d297 authored by patil-suraj's avatar patil-suraj
Browse files

make DiffWave subclass of ModelMixin

parent 86da45bc
......@@ -19,6 +19,8 @@ import torch.nn as nn
import torch.nn.functional as F
import tqdm
from ..modeling_utils import ModelMixin
from ..configuration_utils import ConfigMixin
from ..pipeline_utils import DiffusionPipeline
......@@ -209,14 +211,35 @@ class ResidualGroup(nn.Module):
return skip * math.sqrt(1.0 / self.num_res_layers)
class DiffWave(nn.Module):
def __init__(self, in_channels, res_channels, skip_channels, out_channels,
num_res_layers, dilation_cycle,
diffusion_step_embed_dim_in,
diffusion_step_embed_dim_mid,
diffusion_step_embed_dim_out):
class DiffWave(ModelMixin, ConfigMixin):
def __init__(
self,
in_channels=1,
res_channels=128,
skip_channels=128,
out_channels=1,
num_res_layers=30,
dilation_cycle=10,
diffusion_step_embed_dim_in=128,
diffusion_step_embed_dim_mid=512,
diffusion_step_embed_dim_out=512,
):
super().__init__()
# register all init arguments with self.register
self.register(
in_channels=in_channels,
res_channels=res_channels,
skip_channels=skip_channels,
out_channels=out_channels,
num_res_layers=num_res_layers,
dilation_cycle=dilation_cycle,
diffusion_step_embed_dim_in=diffusion_step_embed_dim_in,
diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid,
diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
)
# Initial conv1x1 with relu
self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU(inplace=False))
# All residual layers
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment