"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "144c3a8b7c336ce5c7de6a9cdc58f0fea3666f33"
Commit 3986741b authored by Patrick von Platen

add another ldm fast test

parent 6846ee2a
@@ -81,61 +81,62 @@ def Normalize(in_channels):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
 
 
-class LinearAttention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super().__init__()
-        self.heads = heads
-        hidden_dim = dim_head * heads
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-        qkv = self.to_qkv(x)
-        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
-        k = k.softmax(dim=-1)
-        context = torch.einsum("bhdn,bhen->bhde", k, v)
-        out = torch.einsum("bhde,bhdn->bhen", context, q)
-        out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
-        return self.to_out(out)
-
-
-class SpatialSelfAttention(nn.Module):
-    def __init__(self, in_channels):
-        super().__init__()
-        self.in_channels = in_channels
-
-        self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-
-    def forward(self, x):
-        h_ = x
-        h_ = self.norm(h_)
-        q = self.q(h_)
-        k = self.k(h_)
-        v = self.v(h_)
-
-        # compute attention
-        b, c, h, w = q.shape
-        q = rearrange(q, "b c h w -> b (h w) c")
-        k = rearrange(k, "b c h w -> b c (h w)")
-        w_ = torch.einsum("bij,bjk->bik", q, k)
-
-        w_ = w_ * (int(c) ** (-0.5))
-        w_ = torch.nn.functional.softmax(w_, dim=2)
-
-        # attend to values
-        v = rearrange(v, "b c h w -> b c (h w)")
-        w_ = rearrange(w_, "b i j -> b j i")
-        h_ = torch.einsum("bij,bjk->bik", v, w_)
-        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
-        h_ = self.proj_out(h_)
-
-        return x + h_
+#class LinearAttention(nn.Module):
+#    def __init__(self, dim, heads=4, dim_head=32):
+#        super().__init__()
+#        self.heads = heads
+#        hidden_dim = dim_head * heads
+#        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+#        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+#
+#    def forward(self, x):
+#        b, c, h, w = x.shape
+#        qkv = self.to_qkv(x)
+#        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
+#        import ipdb; ipdb.set_trace()
+#        k = k.softmax(dim=-1)
+#        context = torch.einsum("bhdn,bhen->bhde", k, v)
+#        out = torch.einsum("bhde,bhdn->bhen", context, q)
+#        out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
+#        return self.to_out(out)
+#
+#
+#class SpatialSelfAttention(nn.Module):
+#    def __init__(self, in_channels):
+#        super().__init__()
+#        self.in_channels = in_channels
+#
+#        self.norm = Normalize(in_channels)
+#        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+#
+#    def forward(self, x):
+#        h_ = x
+#        h_ = self.norm(h_)
+#        q = self.q(h_)
+#        k = self.k(h_)
+#        v = self.v(h_)
+#
+#        # compute attention
+#        b, c, h, w = q.shape
+#        q = rearrange(q, "b c h w -> b (h w) c")
+#        k = rearrange(k, "b c h w -> b c (h w)")
+#        w_ = torch.einsum("bij,bjk->bik", q, k)
+#
+#        w_ = w_ * (int(c) ** (-0.5))
+#        w_ = torch.nn.functional.softmax(w_, dim=2)
+#
+#        # attend to values
+#        v = rearrange(v, "b c h w -> b c (h w)")
+#        w_ = rearrange(w_, "b i j -> b j i")
+#        h_ = torch.einsum("bij,bjk->bik", v, w_)
+#        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+#        h_ = self.proj_out(h_)
+#
+#        return x + h_
 
 
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
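For context (not part of the commit itself): the forward pass of the now commented-out LinearAttention is the einsum-based linear-attention trick, where the keys are softmax-normalized over the spatial axis and folded together with the values into a small per-head context matrix before being applied to the queries, so the cost grows with (h*w) * dim_head^2 rather than (h*w)^2. A minimal, self-contained sketch with illustrative (assumed) shapes, not the module itself:

import torch

# Illustrative shapes (assumptions, not taken from the diff): 4 heads, dim_head=32, an 8x8 feature map.
b, heads, d, n = 1, 4, 32, 64
q = torch.randn(b, heads, d, n)
k = torch.randn(b, heads, d, n)
v = torch.randn(b, heads, d, n)

k = k.softmax(dim=-1)                              # normalize keys over the n = h*w axis
context = torch.einsum("bhdn,bhen->bhde", k, v)    # (b, heads, d, d) context, cost O(n * d^2)
out = torch.einsum("bhde,bhdn->bhen", context, q)  # apply context to queries -> (b, heads, d, n)
print(out.shape)                                   # torch.Size([1, 4, 32, 64])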
@@ -511,6 +511,28 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
 
+    def test_output_pretrained_spatial_transformer(self):
+        model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")
+        model.eval()
+
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+        context = torch.ones((1, 16, 64), dtype=torch.float32)
+        time_step = torch.tensor([10] * noise.shape[0])
+
+        with torch.no_grad():
+            output = model(noise, time_step, context=context)
+
+        output_slice = output[0, -1, -3:, -3:].flatten()
+        # fmt: off
+        expected_output_slice = torch.tensor([61.3445, 56.9005, 29.4339, 59.5497, 60.7375, 34.1719, 48.1951, 42.6569, 25.0890])
+        # fmt: on
+
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+
 
 class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
     model_class = UNetGradTTSModel
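A hedged note on reproducing the new test locally: the dummy checkpoint fusing/unet-ldm-dummy-spatial is downloaded from the Hub on first use, and the test can be selected by name with pytest. The test directory below is an assumption (the scraped diff does not show the file path), so adjust it to wherever UNetLDMModelTests actually lives:

# Sketch only: run just the added test through pytest's Python entry point.
# "tests/" is an assumed location, not taken from the diff.
import pytest

exit_code = pytest.main(["-q", "-k", "test_output_pretrained_spatial_transformer", "tests/"])
print(exit_code)  # 0 if the selected test passes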