Commit f35387b3 authored by patil-suraj

clean Linear

parent 3e2cff4d
@@ -157,6 +157,13 @@ def Conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale
     return conv
 
 
+def Linear(dim_in, dim_out):
+    linear = nn.Linear(dim_in, dim_out)
+    linear.weight.data = _variance_scaling()(linear.weight.shape)
+    nn.init.zeros_(linear.bias)
+    return linear
+
+
 class Combine(nn.Module):
     """Combine information from skip connections."""
@@ -296,13 +303,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
         else:
             raise ValueError(f"embedding type {embedding_type} unknown.")
 
-        if conditional:
-            modules.append(nn.Linear(embed_dim, nf * 4))
-            modules[-1].weight.data = _variance_scaling()(modules[-1].weight.shape)
-            nn.init.zeros_(modules[-1].bias)
-            modules.append(nn.Linear(nf * 4, nf * 4))
-            modules[-1].weight.data = _variance_scaling()(modules[-1].weight.shape)
-            nn.init.zeros_(modules[-1].bias)
+        modules.append(Linear(embed_dim, nf * 4))
+        modules.append(Linear(nf * 4, nf * 4))
 
         AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
         Up_sample = functools.partial(FirUpsample, with_conv=resamp_with_conv, fir_kernel=fir_kernel)
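For context, the new Linear factory delegates weight initialization to _variance_scaling(), which is defined elsewhere in the file and not shown in this diff. The sketch below illustrates how such a variance-scaling initializer is typically written for 2D (out_features, in_features) weights; the scale, mode, and distribution parameters and their defaults are assumptions modeled on the score_sde convention, not necessarily the signature used in this repository.

import torch
import torch.nn as nn


def _variance_scaling(scale=1.0, mode="fan_avg", distribution="uniform"):
    """Hypothetical sketch of a variance-scaling initializer (score_sde style).

    Returns a function that maps a weight shape to a tensor whose variance is
    scale divided by the chosen fan measure.
    """

    def init(shape, dtype=torch.float32, device=None):
        # nn.Linear weights have shape (out_features, in_features)
        fan_in, fan_out = shape[1], shape[0]
        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        else:  # "fan_avg"
            denominator = (fan_in + fan_out) / 2
        variance = scale / denominator
        if distribution == "normal":
            return torch.randn(*shape, dtype=dtype, device=device) * variance**0.5
        # a uniform draw on [-limit, limit] has variance limit**2 / 3
        limit = (3 * variance) ** 0.5
        return (torch.rand(*shape, dtype=dtype, device=device) * 2 - 1) * limit

    return init


def Linear(dim_in, dim_out):
    """The factory added in this commit: nn.Linear with scaled weights and zero bias."""
    linear = nn.Linear(dim_in, dim_out)
    linear.weight.data = _variance_scaling()(linear.weight.shape)
    nn.init.zeros_(linear.bias)
    return linear

With this helper, a call such as Linear(embed_dim, nf * 4) in the second hunk produces a fully initialized layer in one line, replacing the three-line init pattern that was previously repeated for each nn.Linear.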