Unverified Commit a7cc468f authored by Ilmari Heikkinen, committed by GitHub

AutoencoderKL: clamp indices of blend_h and blend_v to input size (#2660)

parent 07a0c1cb
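
For context, here is a minimal standalone sketch of the failure this commit addresses, assuming a toy bottom-edge tile that is shorter than blend_extent (the tensor sizes are illustrative, not taken from the repository):

import torch

# Sketch of the clamped loop introduced by this commit, applied to a toy edge
# tile. The tensor sizes below are assumptions chosen to show the old failure.
def blend_v(a, b, blend_extent):
    # Clamp the loop to the smaller tile height so an edge tile shorter than
    # blend_extent no longer indexes past the end of b.
    for y in range(min(a.shape[2], b.shape[2], blend_extent)):
        b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
    return b

a = torch.ones(1, 4, 64, 64)   # tile above
b = torch.zeros(1, 4, 3, 64)   # bottom-edge tile, only 3 rows tall
out = blend_v(a, b, blend_extent=8)  # with the old range(blend_extent), this indexed past b's height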
@@ -190,12 +190,12 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         return DecoderOutput(sample=decoded)
 
     def blend_v(self, a, b, blend_extent):
-        for y in range(blend_extent):
+        for y in range(min(a.shape[2], b.shape[2], blend_extent)):
             b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
         return b
 
     def blend_h(self, a, b, blend_extent):
-        for x in range(blend_extent):
+        for x in range(min(a.shape[3], b.shape[3], blend_extent)):
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
         return b
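For orientation, a simplified sketch of how blend_v and blend_h are typically applied while stitching decoded tiles. The loop structure and the names rows, row_limit, and col_limit are assumptions modelled on a common tiled-decode pattern and are not part of this diff:

import torch

# Each tile is blended with the tile above it and the tile to its left before
# being cropped and concatenated. blend_v/blend_h are the functions patched
# above; everything else in this helper is an assumption.
def stitch_tiles(rows, blend_extent, row_limit, col_limit, blend_v, blend_h):
    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            if i > 0:
                tile = blend_v(rows[i - 1][j], tile, blend_extent)
            if j > 0:
                tile = blend_h(row[j - 1], tile, blend_extent)
            result_row.append(tile[:, :, :row_limit, :col_limit])
        result_rows.append(torch.cat(result_row, dim=3))
    return torch.cat(result_rows, dim=2)

Edge tiles in such a grid can be smaller than blend_extent in either dimension, which is why the clamped loop above matters.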
@@ -445,6 +445,12 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 5e-1
 
+        # test that tiled decode works with various shapes
+        shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
+        for shape in shapes:
+            zeros = torch.zeros(shape).to(device)
+            sd_pipe.vae.decode(zeros)
+
     def test_stable_diffusion_negative_prompt(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
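Outside the test suite, the same tiled path can be exercised by enabling tiling on a full-size VAE and decoding latents whose spatial size is not a multiple of the tile stride; the checkpoint id below is only an example:

import torch
from diffusers import AutoencoderKL

# Hypothetical user-level example: enable tiled decoding and decode latents
# with an odd spatial size, mirroring the shapes used in the test above.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae.enable_tiling()
latents = torch.zeros(1, 4, 73, 97)
with torch.no_grad():
    image = vae.decode(latents).sample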