Unverified commit e5d9baf0, authored by Patrick von Platen, committed by GitHub

Merge pull request #38 from huggingface/one_attentino_module

Unify attention modules
parents e47c97a4 c482d7bd
 import math

 import torch
-import torch.nn.functional as F
 from torch import nn


 # unet_grad_tts.py
+# TODO(Patrick) - weird linear attention layer. Check with: https://github.com/huawei-noah/Speech-Backbones/issues/15
 class LinearAttention(torch.nn.Module):
     def __init__(self, dim, heads=4, dim_head=32):
         super(LinearAttention, self).__init__()
@@ -18,7 +18,6 @@ class LinearAttention(torch.nn.Module):
     def forward(self, x):
         b, c, h, w = x.shape
         qkv = self.to_qkv(x)
-        # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
         q, k, v = (
             qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
             .permute(1, 0, 2, 3, 4, 5)
@@ -27,12 +26,11 @@ class LinearAttention(torch.nn.Module):
         k = k.softmax(dim=-1)
         context = torch.einsum("bhdn,bhen->bhde", k, v)
         out = torch.einsum("bhde,bhdn->bhen", context, q)
-        # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
         out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
         return self.to_out(out)
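For reference, the reshape/permute chain above is a hand-written replacement for the einops calls kept in the removed comments; a minimal standalone check of that equivalence (einops assumed to be available, shapes picked arbitrarily):

import torch
from einops import rearrange  # assumption: einops is installed; it is not a dependency of this file

b, heads, dim_head, h, w = 2, 4, 32, 8, 8
qkv = torch.randn(b, 3 * heads * dim_head, h, w)

# einops version from the removed comment
q1, k1, v1 = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=heads, qkv=3)

# pure-PyTorch version used in LinearAttention.forward
q2, k2, v2 = (
    qkv.reshape(b, 3, heads, dim_head, h, w)
    .permute(1, 0, 2, 3, 4, 5)
    .reshape(3, b, heads, dim_head, -1)
)

assert torch.equal(q1, q2) and torch.equal(k1, k2) and torch.equal(v1, v2)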
-# unet_glide.py & unet_ldm.py
+# the main attention block that is used for all models
 class AttentionBlock(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other.
@@ -46,10 +44,13 @@ class AttentionBlock(nn.Module):
         channels,
         num_heads=1,
         num_head_channels=-1,
+        num_groups=32,
         use_checkpoint=False,
         encoder_channels=None,
         use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
         overwrite_qkv=False,
+        overwrite_linear=False,
+        rescale_output_factor=1.0,
     ):
         super().__init__()
         self.channels = channels
@@ -62,41 +63,67 @@ class AttentionBlock(nn.Module):
             self.num_heads = channels // num_head_channels
         self.use_checkpoint = use_checkpoint
-        self.norm = normalization(channels, swish=0.0)
-        self.qkv = conv_nd(1, channels, channels * 3, 1)
+        self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-5, affine=True)
+        self.qkv = nn.Conv1d(channels, channels * 3, 1)
         self.n_heads = self.num_heads
+        self.rescale_output_factor = rescale_output_factor

         if encoder_channels is not None:
-            self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
+            self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)

-        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))

         self.overwrite_qkv = overwrite_qkv

         if overwrite_qkv:
             in_channels = channels
+            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
             self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
             self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
             self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
             self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

+        self.overwrite_linear = overwrite_linear
+        if self.overwrite_linear:
+            num_groups = min(channels // 4, 32)
+            self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
+            self.NIN_0 = NIN(channels, channels)
+            self.NIN_1 = NIN(channels, channels)
+            self.NIN_2 = NIN(channels, channels)
+            self.NIN_3 = NIN(channels, channels)
+
+            self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)

         self.is_overwritten = False

     def set_weights(self, module):
         if self.overwrite_qkv:
-            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[:, :, :, 0]
+            qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
+                :, :, :, 0
+            ]
             qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)

             self.qkv.weight.data = qkv_weight
             self.qkv.bias.data = qkv_bias

-            proj_out = zero_module(conv_nd(1, self.channels, self.channels, 1))
+            proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
             proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
             proj_out.bias.data = module.proj_out.bias.data

             self.proj_out = proj_out
+        elif self.overwrite_linear:
+            self.qkv.weight.data = torch.concat(
+                [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
+            )[:, :, None]
+            self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
+
+            self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
+            self.proj_out.bias.data = self.NIN_3.b.data
+
+            self.norm.weight.data = self.GroupNorm_0.weight.data
+            self.norm.bias.data = self.GroupNorm_0.bias.data

     def forward(self, x, encoder_out=None):
-        if self.overwrite_qkv and not self.is_overwritten:
+        if (self.overwrite_qkv or self.overwrite_linear) and not self.is_overwritten:
             self.set_weights(self)
             self.is_overwritten = True
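A quick shape-level sketch of the unified block in its default mode, i.e. no weight overwriting and rescale_output_factor=1.0 (the diffusers.models.attention module path is an assumption based on the package-relative imports further down):

import torch
from diffusers.models.attention import AttentionBlock  # assumed import path for this revision

x = torch.randn(1, 64, 16, 16)              # (batch, channels, height, width)
attn = AttentionBlock(channels=64, num_heads=2)
out = attn(x)                               # self-attention over the 16 * 16 spatial positions
assert out.shape == x.shape                 # the block is residual, so the input shape is preserved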
@@ -124,69 +151,74 @@ class AttentionBlock(nn.Module):
         h = a.reshape(bs, -1, length)

         h = self.proj_out(h)
+        h = h.reshape(b, c, *spatial)

-        return x + h.reshape(b, c, *spatial)
+        result = x + h

-def conv_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D convolution module.
-    """
-    if dims == 1:
-        return nn.Conv1d(*args, **kwargs)
-    elif dims == 2:
-        return nn.Conv2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.Conv3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-class GroupNorm32(nn.GroupNorm):
-    def __init__(self, num_groups, num_channels, swish, eps=1e-5, affine=True):
-        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine)
-        self.swish = swish
-
-    def forward(self, x):
-        y = super().forward(x.float()).to(x.dtype)
-        if self.swish == 1.0:
-            y = F.silu(y)
-        elif self.swish:
-            y = y * F.sigmoid(y * float(self.swish))
-        return y
-
-def normalization(channels, swish=0.0, eps=1e-5):
-    """
-    Make a standard normalization layer, with an optional swish activation.
-    :param channels: number of input channels. :return: an nn.Module for normalization.
-    """
-    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish, eps=eps, affine=True)
+        result = result / self.rescale_output_factor

-def zero_module(module):
-    """
-    Zero out the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
+        return result
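The attention computation itself is collapsed in this hunk, but the commented forward_2 further down spells it out: flatten the spatial grid, split qkv per head, apply a scaled einsum plus softmax, then fold the heads back and add the residual. A self-contained toy version of that pipeline, using nothing beyond PyTorch:

import math

import torch
from torch import nn

b, c, height, width = 1, 8, 4, 4
n_heads = 2
x = torch.randn(b, c, height, width)

norm = nn.GroupNorm(num_groups=4, num_channels=c, eps=1e-5)
to_qkv = nn.Conv1d(c, c * 3, 1)                          # same layout as self.qkv in AttentionBlock

qkv = to_qkv(norm(x).view(b, c, -1))                     # (b, 3*c, h*w)
bs, w_ch, length = qkv.shape
ch = w_ch // (3 * n_heads)
q, k, v = qkv.reshape(bs * n_heads, ch * 3, length).split(ch, dim=1)

scale = 1 / math.sqrt(math.sqrt(ch))                     # scale q and k separately: friendlier to fp16
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

a = torch.einsum("bts,bcs->bct", weight, v)              # attention-weighted values
h = a.reshape(bs, -1, length).reshape(b, c, height, width)  # fold heads back; the real block applies proj_out first
result = (x + h) / 1.0                                   # residual, divided by rescale_output_factor
assert result.shape == x.shape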
 # unet_score_estimation.py
 # class AttnBlockpp(nn.Module):
 #     """Channel-wise self-attention block. Modified from DDPM."""
 #
-#     def __init__(self, channels, skip_rescale=False, init_scale=0.0):
+#     def __init__(
+#         self,
+#         channels,
+#         skip_rescale=False,
+#         init_scale=0.0,
+#         num_heads=1,
+#         num_head_channels=-1,
+#         use_checkpoint=False,
+#         encoder_channels=None,
+#         use_new_attention_order=False,  # TODO(Patrick) -> is never used, maybe delete?
+#         overwrite_qkv=False,
+#         overwrite_from_grad_tts=False,
+#     ):
 #         super().__init__()
-#         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
+#         num_groups = min(channels // 4, 32)
+#         self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
 #         self.NIN_0 = NIN(channels, channels)
 #         self.NIN_1 = NIN(channels, channels)
 #         self.NIN_2 = NIN(channels, channels)
 #         self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
 #         self.skip_rescale = skip_rescale
 #
+#         self.channels = channels
+#         if num_head_channels == -1:
+#             self.num_heads = num_heads
+#         else:
+#             assert (
+#                 channels % num_head_channels == 0
+#             ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+#             self.num_heads = channels // num_head_channels
+#
+#         self.use_checkpoint = use_checkpoint
+#         self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
+#         self.qkv = nn.Conv1d(channels, channels * 3, 1)
+#         self.n_heads = self.num_heads
+#
+#         self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+#
+#         self.is_weight_set = False
+#
+#     def set_weights(self):
+#         self.qkv.weight.data = torch.concat([self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0)[:, :, None]
+#         self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
+#
+#         self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
+#         self.proj_out.bias.data = self.NIN_3.b.data
+#
+#         self.norm.weight.data = self.GroupNorm_0.weight.data
+#         self.norm.bias.data = self.GroupNorm_0.bias.data
+#
 #     def forward(self, x):
+#         if not self.is_weight_set:
+#             self.set_weights()
+#             self.is_weight_set = True
+#
 #         B, C, H, W = x.shape
 #         h = self.GroupNorm_0(x)
 #         q = self.NIN_0(h)
@@ -199,7 +231,59 @@ def zero_module(module):
 #         w = torch.reshape(w, (B, H, W, H, W))
 #         h = torch.einsum("bhwij,bcij->bchw", w, v)
 #         h = self.NIN_3(h)
+#
 #         if not self.skip_rescale:
-#             return x + h
+#             result = x + h
 #         else:
-#             return (x + h) / np.sqrt(2.0)
+#             result = (x + h) / np.sqrt(2.0)
+#
+#         result = self.forward_2(x)
+#
+#         return result
+#
+#     def forward_2(self, x, encoder_out=None):
+#         b, c, *spatial = x.shape
+#         hid_states = self.norm(x).view(b, c, -1)
+#
+#         qkv = self.qkv(hid_states)
+#         bs, width, length = qkv.shape
+#         assert width % (3 * self.n_heads) == 0
+#         ch = width // (3 * self.n_heads)
+#         q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+#
+#         if encoder_out is not None:
+#             encoder_kv = self.encoder_kv(encoder_out)
+#             assert encoder_kv.shape[1] == self.n_heads * ch * 2
+#             ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
+#             k = torch.cat([ek, k], dim=-1)
+#             v = torch.cat([ev, v], dim=-1)
+#
+#         scale = 1 / math.sqrt(math.sqrt(ch))
+#         weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
+#         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+#
+#         a = torch.einsum("bts,bcs->bct", weight, v)
+#         h = a.reshape(bs, -1, length)
+#
+#         h = self.proj_out(h)
+#         h = h.reshape(b, c, *spatial)
+#
+#         return (x + h) / np.sqrt(2.0)
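The encoder_out branch above is how encoder conditioning enters the block: extra keys and values produced by encoder_kv are concatenated in front of the spatial ones, so the softmax attends over encoder tokens and spatial positions jointly. A toy illustration of just that concatenation (all shapes made up):

import torch

bs_heads, ch, length, enc_len = 2, 4, 16, 7

k = torch.randn(bs_heads, ch, length)       # keys over the h*w spatial positions
v = torch.randn(bs_heads, ch, length)
ek = torch.randn(bs_heads, ch, enc_len)     # stand-ins for the split of encoder_kv(encoder_out)
ev = torch.randn(bs_heads, ch, enc_len)

k = torch.cat([ek, k], dim=-1)              # encoder tokens are prepended along the sequence axis
v = torch.cat([ev, v], dim=-1)
assert k.shape == (bs_heads, ch, enc_len + length) and v.shape == k.shape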
+# TODO(Patrick) - this can and should be removed
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
+class NIN(nn.Module):
+    def __init__(self, in_dim, num_units, init_scale=0.1):
+        super().__init__()
+        self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True)
+        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
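The overwrite_linear branch of set_weights relies on a 1x1 Conv1d whose weight is W.T[:, :, None] computing the same per-position linear map as a NIN layer (x @ W + b along the channel axis). A small standalone check of that equivalence, with random tensors standing in for trained NIN parameters (the NIN shim above initializes W and b to zeros, so the real values only appear once a checkpoint is loaded):

import torch
from torch import nn

channels, length = 8, 16
x = torch.randn(2, channels, length)                # (batch, channels, positions)

# Stand-ins for trained NIN parameters: NIN computes y[b, d, l] = sum_c x[b, c, l] * W[c, d] + b[d]
W = torch.randn(channels, channels)
b = torch.randn(channels)
nin_out = torch.einsum("bcl,cd->bdl", x, W) + b[None, :, None]

# The same map expressed as a 1x1 Conv1d, ported exactly the way set_weights does it
conv = nn.Conv1d(channels, channels, 1)
conv.weight.data = W.T[:, :, None]                  # (out_channels, in_channels, kernel_size=1)
conv.bias.data = b
conv_out = conv(x)

assert torch.allclose(nin_out, conv_out, atol=1e-5)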
@@ -65,19 +65,3 @@ class GaussianFourierProjection(nn.Module):
     def forward(self, x):
         x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
         return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
-
-
-# unet_rl.py - TODO(need test)
-class SinusoidalPosEmb(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = x[:, None] * emb[None, :]
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
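For reference, the removed SinusoidalPosEmb maps a batch of scalar timesteps to dim-dimensional features: sines and cosines at geometrically spaced frequencies. A toy run of the same formula:

import math

import torch

dim = 8
t = torch.tensor([0.0, 1.0, 10.0])                  # a batch of three timesteps

half_dim = dim // 2
freqs = torch.exp(torch.arange(half_dim) * -(math.log(10000) / (half_dim - 1)))
emb = t[:, None] * freqs[None, :]                   # (3, half_dim)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)     # (3, dim)
assert emb.shape == (3, dim)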
@@ -15,24 +15,14 @@
 # helpers functions

-import copy
-import math
-from pathlib import Path
-
 import torch
 from torch import nn
-from torch.cuda.amp import GradScaler, autocast
-from torch.optim import Adam
-from torch.utils import data
-from PIL import Image
-from tqdm import tqdm

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
-from .attention2d import AttentionBlock


 def nonlinearity(x):
@@ -219,11 +209,7 @@ class UNetModel(ModelMixin, ConfigMixin):
             for i_block in range(self.num_res_blocks):
                 h = self.down[i_level].block[i_block](hs[-1], temb)
                 if len(self.down[i_level].attn) > 0:
-                    # self.down[i_level].attn_2[i_block].set_weights(self.down[i_level].attn[i_block])
-                    # h = self.down[i_level].attn_2[i_block](h)
                     h = self.down[i_level].attn[i_block](h)
-                    # print("Result", (h - h_2).abs().sum())
                 hs.append(h)
             if i_level != self.num_resolutions - 1:
                 hs.append(self.down[i_level].downsample(hs[-1]))
......
@@ -6,7 +6,7 @@ import torch.nn.functional as F
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
-from .attention2d import AttentionBlock
+from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
......
 import torch
-from numpy import pad

 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .attention import LinearAttention
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
@@ -54,32 +54,6 @@ class ResnetBlock(torch.nn.Module):
         return output


-class LinearAttention(torch.nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super(LinearAttention, self).__init__()
-        self.heads = heads
-        self.dim_head = dim_head
-        hidden_dim = dim_head * heads
-        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-        qkv = self.to_qkv(x)
-        # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
-        q, k, v = (
-            qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
-            .permute(1, 0, 2, 3, 4, 5)
-            .reshape(3, b, self.heads, self.dim_head, -1)
-        )
-        k = k.softmax(dim=-1)
-        context = torch.einsum("bhdn,bhen->bhde", k, v)
-        out = torch.einsum("bhde,bhdn->bhen", context, q)
-        # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
-        out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
-        return self.to_out(out)


 class Residual(torch.nn.Module):
     def __init__(self, fn):
         super(Residual, self).__init__()
......
@@ -9,7 +9,7 @@ import torch.nn.functional as F
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
-from .attention2d import AttentionBlock
+from .attention import AttentionBlock
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample, Upsample
......
@@ -16,6 +16,7 @@
 # helpers functions

 import functools
+import math
 import string

 import numpy as np
@@ -25,6 +26,7 @@ import torch.nn.functional as F
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from .attention import AttentionBlock
 from .embeddings import GaussianFourierProjection, get_timestep_embedding
@@ -414,37 +416,6 @@ class Combine(nn.Module):
             raise ValueError(f"Method {self.method} not recognized.")


-class AttnBlockpp(nn.Module):
-    """Channel-wise self-attention block. Modified from DDPM."""
-
-    def __init__(self, channels, skip_rescale=False, init_scale=0.0):
-        super().__init__()
-        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6)
-        self.NIN_0 = NIN(channels, channels)
-        self.NIN_1 = NIN(channels, channels)
-        self.NIN_2 = NIN(channels, channels)
-        self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
-        self.skip_rescale = skip_rescale
-
-    def forward(self, x):
-        B, C, H, W = x.shape
-        h = self.GroupNorm_0(x)
-        q = self.NIN_0(h)
-        k = self.NIN_1(h)
-        v = self.NIN_2(h)
-
-        w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5))
-        w = torch.reshape(w, (B, H, W, H * W))
-        w = F.softmax(w, dim=-1)
-        w = torch.reshape(w, (B, H, W, H, W))
-        h = torch.einsum("bhwij,bcij->bchw", w, v)
-        h = self.NIN_3(h)
-
-        if not self.skip_rescale:
-            return x + h
-        else:
-            return (x + h) / np.sqrt(2.0)


 class Upsample(nn.Module):
     def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
         super().__init__()
@@ -756,8 +727,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
            modules[-1].weight.data = default_init()(modules[-1].weight.shape)
            nn.init.zeros_(modules[-1].bias)

-        AttnBlock = functools.partial(AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale)
+        AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))
         Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

         if progressive == "output_skip":
......
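This functools.partial swap is the crux of the change for NCSN++: the hand-rolled AttnBlockpp is replaced by the unified AttentionBlock, with overwrite_linear=True so the NIN-style checkpoint weights are folded into the Conv1d projections on the first forward pass, and rescale_output_factor=math.sqrt(2.0) so the residual becomes (x + h) / sqrt(2), matching the old skip-rescale behaviour. A hedged usage sketch (the import path and running it against this exact revision are assumptions):

import functools
import math

import torch
from diffusers.models.attention import AttentionBlock  # assumed module path

AttnBlock = functools.partial(AttentionBlock, overwrite_linear=True, rescale_output_factor=math.sqrt(2.0))

attn = AttnBlock(channels=64)               # call sites keep passing only `channels`, as before
out = attn(torch.randn(1, 64, 8, 8))        # NIN weights are copied into qkv/proj_out on the first call
assert out.shape == (1, 64, 8, 8)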
@@ -859,7 +859,9 @@ class PipelineTesterMixin(unittest.TestCase):
         image_slice = image[0, -1, -3:, -3:].cpu()

         assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor([-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105])
+        expected_slice = torch.tensor(
+            [-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105]
+        )
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

     @slow
......