"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "d57f929da98f927365bdef303d3453712067e15b"
Unverified commit 8b0bc596 authored by Suraj Patil, committed by GitHub

Merge pull request #52 from huggingface/clean-unet-sde

Clean UNetNCSNpp
parents 7e0fd19f f35387b3
@@ -579,7 +579,6 @@ class ResnetBlockBigGANpp(nn.Module):
         up=False,
         down=False,
         dropout=0.1,
-        fir=False,
         fir_kernel=(1, 3, 3, 1),
         skip_rescale=True,
         init_scale=0.0,
@@ -590,20 +589,20 @@ class ResnetBlockBigGANpp(nn.Module):
         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
         self.up = up
         self.down = down
-        self.fir = fir
         self.fir_kernel = fir_kernel
-        self.Conv_0 = conv3x3(in_ch, out_ch)
+        self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
         if temb_dim is not None:
             self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
+            self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
             nn.init.zeros_(self.Dense_0.bias)
         self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
         self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+        self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)
         if in_ch != out_ch or up or down:
-            self.Conv_2 = conv1x1(in_ch, out_ch)
+            # 1x1 convolution with DDPM initialization.
+            self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
         self.skip_rescale = skip_rescale
         self.act = act
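The `if temb_dim is not None` branch above sets up the time-embedding conditioning used in the forward pass: temb goes through the activation, is projected to out_ch, and is added as a per-channel bias. A minimal sketch of that pattern (silu stands in for the configurable act argument):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

temb = torch.randn(4, 512)       # (N, temb_dim) time embedding
dense = nn.Linear(512, 128)      # Dense_0: temb_dim -> out_ch
h = torch.randn(4, 128, 32, 32)  # (N, out_ch, H, W) feature maps

# act(temb) -> Dense_0, then broadcast over H and W as a per-channel bias
h = h + dense(F.silu(temb))[:, :, None, None]
```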
@@ -614,19 +613,11 @@ class ResnetBlockBigGANpp(nn.Module):
         h = self.act(self.GroupNorm_0(x))
         if self.up:
-            if self.fir:
-                h = upsample_2d(h, self.fir_kernel, factor=2)
-                x = upsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = naive_upsample_2d(h, factor=2)
-                x = naive_upsample_2d(x, factor=2)
+            h = upsample_2d(h, self.fir_kernel, factor=2)
+            x = upsample_2d(x, self.fir_kernel, factor=2)
         elif self.down:
-            if self.fir:
-                h = downsample_2d(h, self.fir_kernel, factor=2)
-                x = downsample_2d(x, self.fir_kernel, factor=2)
-            else:
-                h = naive_downsample_2d(h, factor=2)
-                x = naive_downsample_2d(x, factor=2)
+            h = downsample_2d(h, self.fir_kernel, factor=2)
+            x = downsample_2d(x, self.fir_kernel, factor=2)
         h = self.Conv_0(h)
         # Add bias to each feature map conditioned on the time embedding
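With the fir flag gone, the block now always resamples through the FIR kernel. For reference, the naive fallback paths this hunk deletes match standard torch ops; a quick sketch of the equivalence (not part of the PR):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 64, 32, 32)
up = F.interpolate(x, scale_factor=2, mode="nearest")  # == naive_upsample_2d(x, factor=2)
down = F.avg_pool2d(x, kernel_size=2)                  # == naive_downsample_2d(x, factor=2)
```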
@@ -645,62 +636,6 @@ class ResnetBlockBigGANpp(nn.Module):
         return (x + h) / np.sqrt(2.0)
-
-
-# unet_score_estimation.py
-class ResnetBlockDDPMpp(nn.Module):
-    """ResBlock adapted from DDPM."""
-
-    def __init__(
-        self,
-        act,
-        in_ch,
-        out_ch=None,
-        temb_dim=None,
-        conv_shortcut=False,
-        dropout=0.1,
-        skip_rescale=False,
-        init_scale=0.0,
-    ):
-        super().__init__()
-        out_ch = out_ch if out_ch else in_ch
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
-        self.Conv_0 = conv3x3(in_ch, out_ch)
-        if temb_dim is not None:
-            self.Dense_0 = nn.Linear(temb_dim, out_ch)
-            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
-            nn.init.zeros_(self.Dense_0.bias)
-        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
-        self.Dropout_0 = nn.Dropout(dropout)
-        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
-        if in_ch != out_ch:
-            if conv_shortcut:
-                self.Conv_2 = conv3x3(in_ch, out_ch)
-            else:
-                self.NIN_0 = NIN(in_ch, out_ch)
-        self.skip_rescale = skip_rescale
-        self.act = act
-        self.out_ch = out_ch
-        self.conv_shortcut = conv_shortcut
-
-    def forward(self, x, temb=None):
-        h = self.act(self.GroupNorm_0(x))
-        h = self.Conv_0(h)
-        if temb is not None:
-            h += self.Dense_0(self.act(temb))[:, :, None, None]
-        h = self.act(self.GroupNorm_1(h))
-        h = self.Dropout_0(h)
-        h = self.Conv_1(h)
-        if x.shape[1] != self.out_ch:
-            if self.conv_shortcut:
-                x = self.Conv_2(x)
-            else:
-                x = self.NIN_0(x)
-        if not self.skip_rescale:
-            return x + h
-        else:
-            return (x + h) / np.sqrt(2.0)
 # unet_rl.py
 class ResidualTemporalBlock(nn.Module):
     def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
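Both the BigGAN-style block and the deleted DDPM block share the skip_rescale trick seen in their return lines: if x and h are roughly independent with unit variance, x + h has variance about 2, so dividing by sqrt(2) keeps activations at unit variance. A quick numerical check (a sketch):

```python
import torch

x, h = torch.randn(100_000), torch.randn(100_000)
print(torch.var(x + h).item())               # ~2.0
print(torch.var((x + h) / 2 ** 0.5).item())  # ~1.0
```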
@@ -818,32 +753,17 @@ class RearrangeDim(nn.Module):
         raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
-def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0):
-    """1x1 convolution with DDPM initialization."""
-    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale=1.0, padding=1):
+    """nXn convolution with DDPM initialization."""
+    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
+    conv.weight.data = variance_scaling(init_scale)(conv.weight.data.shape)
     nn.init.zeros_(conv.bias)
     return conv
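The old conv1x1 and conv3x3 helpers collapse into this single conv2d; old call sites map directly. A sketch of the mapping (assumes the conv2d helper defined above is in scope):

```python
conv3 = conv2d(64, 128)                            # previously conv3x3(64, 128)
conv1 = conv2d(64, 128, kernel_size=1, padding=0)  # previously conv1x1(64, 128)
```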
-def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1):
-    """3x3 convolution with DDPM initialization."""
-    conv = nn.Conv2d(
-        in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias
-    )
-    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
-    nn.init.zeros_(conv.bias)
-    return conv
-
-
-def default_init(scale=1.0):
-    """The same initialization used in DDPM."""
-    scale = 1e-10 if scale == 0 else scale
-    return variance_scaling(scale, "fan_avg", "uniform")
-
-
-def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
+def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
     """Ported from JAX."""
+    scale = 1e-10 if scale == 0 else scale
     def _compute_fans(shape, in_axis=1, out_axis=0):
         receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
@@ -853,21 +773,9 @@ def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=tor
     def init(shape, dtype=dtype, device=device):
         fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
-        if mode == "fan_in":
-            denominator = fan_in
-        elif mode == "fan_out":
-            denominator = fan_out
-        elif mode == "fan_avg":
-            denominator = (fan_in + fan_out) / 2
-        else:
-            raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
+        denominator = (fan_in + fan_out) / 2
         variance = scale / denominator
-        if distribution == "normal":
-            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
-        elif distribution == "uniform":
-            return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
-        else:
-            raise ValueError("invalid distribution for variance scaling initializer")
+        return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)
     return init
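Since the new variance_scaling is split across two hunks, here is the surviving initializer consolidated into one self-contained sketch: only the fan_avg mode and uniform distribution remain, with default_init's zero-scale guard folded in (the fan_in/fan_out lines hidden by the hunk boundary are reconstructed and may differ in detail):

```python
import numpy as np
import torch

def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
    """Ported from JAX: fan_avg variance scaling with a uniform distribution."""
    scale = 1e-10 if scale == 0 else scale

    def _compute_fans(shape, in_axis=1, out_axis=0):
        receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
        fan_in = shape[in_axis] * receptive_field_size
        fan_out = shape[out_axis] * receptive_field_size
        return fan_in, fan_out

    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
        variance = scale / ((fan_in + fan_out) / 2)
        # uniform on [-sqrt(3 * variance), sqrt(3 * variance)] has the target variance
        return (torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance)

    return init

weight = variance_scaling()((128, 64, 3, 3))  # e.g. a 3x3 conv weight
```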
@@ -965,31 +873,6 @@ def downsample_2d(x, k=None, factor=2, gain=1):
     return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
-
-
-def naive_upsample_2d(x, factor=2):
-    _N, C, H, W = x.shape
-    x = torch.reshape(x, (-1, C, H, 1, W, 1))
-    x = x.repeat(1, 1, 1, factor, 1, factor)
-    return torch.reshape(x, (-1, C, H * factor, W * factor))
-
-
-def naive_downsample_2d(x, factor=2):
-    _N, C, H, W = x.shape
-    x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
-    return torch.mean(x, dim=(3, 5))
-
-
-class NIN(nn.Module):
-    def __init__(self, in_dim, num_units, init_scale=0.1):
-        super().__init__()
-        self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
-        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-
-    def forward(self, x):
-        x = x.permute(0, 2, 3, 1)
-        y = contract_inner(x, self.W) + self.b
-        return y.permute(0, 3, 1, 2)
 def _setup_kernel(k):
     k = np.asarray(k, dtype=np.float32)
     if k.ndim == 1:
@@ -998,17 +881,3 @@ def _setup_kernel(k):
     assert k.ndim == 2
     assert k.shape[0] == k.shape[1]
     return k
-
-
-def contract_inner(x, y):
-    """tensordot(x, y, 1)."""
-    x_chars = list(string.ascii_lowercase[: len(x.shape)])
-    y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
-    y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
-    out_chars = x_chars[:-1] + y_chars[1:]
-    return _einsum(x_chars, y_chars, out_chars, x, y)
-
-
-def _einsum(a, b, c, x, y):
-    einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
-    return torch.einsum(einsum_str, x, y)
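The deleted NIN layer (with its contract_inner/_einsum helpers) applies a position-wise linear map over the channel axis, which is exactly a 1x1 convolution; that is why the cleaned-up blocks can use conv2d with kernel_size=1 instead. A sketch verifying the equivalence:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 16, 8, 8)
W = torch.randn(16, 32)  # NIN weight: (in_dim, num_units)
b = torch.zeros(32)

# NIN: permute to NHWC, contract the channel axis against W, permute back
y_nin = (torch.einsum("nhwc,cd->nhwd", x.permute(0, 2, 3, 1), W) + b).permute(0, 3, 1, 2)

# The same map as a 1x1 convolution with the transposed weight
conv = nn.Conv2d(16, 32, kernel_size=1)
conv.weight.data = W.t()[:, :, None, None]
conv.bias.data = b
print(torch.allclose(y_nin, conv(x), atol=1e-5))  # True
```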