Commit f1cb8074 authored by patil-suraj's avatar patil-suraj
Browse files

remove get_act

parent 13ac40ed
...@@ -295,21 +295,6 @@ def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale ...@@ -295,21 +295,6 @@ def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale
return conv return conv
def get_act(nonlinearity):
"""Get activation functions from the config file."""
if nonlinearity.lower() == "elu":
return nn.ELU()
elif nonlinearity.lower() == "relu":
return nn.ReLU()
elif nonlinearity.lower() == "lrelu":
return nn.LeakyReLU(negative_slope=0.2)
elif nonlinearity.lower() == "swish":
return nn.SiLU()
else:
raise NotImplementedError("activation function does not exist!")
def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"): def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
"""Ported from JAX.""" """Ported from JAX."""
scale = 1e-10 if scale == 0 else scale scale = 1e-10 if scale == 0 else scale
...@@ -467,7 +452,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -467,7 +452,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
skip_rescale=skip_rescale, skip_rescale=skip_rescale,
continuous=continuous, continuous=continuous,
) )
self.act = act = get_act(nonlinearity) self.act = act = nn.SiLU()
self.nf = nf self.nf = nf
self.num_res_blocks = num_res_blocks self.num_res_blocks = num_res_blocks
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment