Commit 82e6fa56 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make style

parent edb087a2
......@@ -314,13 +314,13 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
inputs = self.get_inputs(torch_device)
# change input image to a random size (one that would cause a tensor mismatch error)
inputs['image'] = inputs['image'].resize((127,127))
inputs['mask_image'] = inputs['mask_image'].resize((127,127))
inputs['height'] = 128
inputs['width'] = 128
inputs["image"] = inputs["image"].resize((127, 127))
inputs["mask_image"] = inputs["mask_image"].resize((127, 127))
inputs["height"] = 128
inputs["width"] = 128
image = pipe(**inputs).images
# verify that the returned image has the same height and width as the input height and width
assert image.shape == (1, inputs['height'], inputs['width'], 3)
assert image.shape == (1, inputs["height"], inputs["width"], 3)
@nightly
......@@ -451,7 +451,18 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
im_np = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
im_pil = Image.fromarray(im_np)
mask_np = np.random.randint(0, 255, (height, width,), dtype=np.uint8) > 127.5
mask_np = (
np.random.randint(
0,
255,
(
height,
width,
),
dtype=np.uint8,
)
> 127.5
)
mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
......@@ -463,12 +474,34 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
def test_torch_3D_2D_inputs(self):
height, width = 32, 32
im_tensor = torch.randint(0, 255, (3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (height, width,), dtype=torch.uint8) > 127.5
im_tensor = torch.randint(
0,
255,
(
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_np = im_tensor.numpy().transpose(1, 2, 0)
mask_np = mask_tensor.numpy()
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width
)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
self.assertTrue((t_mask_tensor == t_mask_np).all())
......@@ -477,12 +510,35 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
def test_torch_3D_3D_inputs(self):
height, width = 32, 32
im_tensor = torch.randint(0, 255, (3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (1, height, width,), dtype=torch.uint8) > 127.5
im_tensor = torch.randint(
0,
255,
(
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
1,
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_np = im_tensor.numpy().transpose(1, 2, 0)
mask_np = mask_tensor.numpy()[0]
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width
)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
self.assertTrue((t_mask_tensor == t_mask_np).all())
......@@ -491,12 +547,35 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
def test_torch_4D_2D_inputs(self):
height, width = 32, 32
im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (height, width,), dtype=torch.uint8) > 127.5
im_tensor = torch.randint(
0,
255,
(
1,
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
mask_np = mask_tensor.numpy()
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width
)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
self.assertTrue((t_mask_tensor == t_mask_np).all())
......@@ -505,12 +584,36 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
def test_torch_4D_3D_inputs(self):
height, width = 32, 32
im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (1, height, width,), dtype=torch.uint8) > 127.5
im_tensor = torch.randint(
0,
255,
(
1,
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
1,
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
mask_np = mask_tensor.numpy()[0]
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width
)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
self.assertTrue((t_mask_tensor == t_mask_np).all())
......@@ -519,12 +622,37 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
def test_torch_4D_4D_inputs(self):
height, width = 32, 32
im_tensor = torch.randint(0, 255, (1, 3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (1, 1, height, width,), dtype=torch.uint8) > 127.5
im_tensor = torch.randint(
0,
255,
(
1,
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
1,
1,
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
mask_np = mask_tensor.numpy()[0][0]
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width
)
t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width)
self.assertTrue((t_mask_tensor == t_mask_np).all())
......@@ -533,13 +661,37 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
def test_torch_batch_4D_3D(self):
height, width = 32, 32
im_tensor = torch.randint(0, 255, (2, 3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (2, height, width,), dtype=torch.uint8) > 127.5
im_tensor = torch.randint(
0,
255,
(
2,
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
2,
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
mask_nps = [mask.numpy() for mask in mask_tensor]
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width
)
nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)]
t_mask_np = torch.cat([n[0] for n in nps])
t_masked_np = torch.cat([n[1] for n in nps])
......@@ -550,13 +702,38 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
def test_torch_batch_4D_4D(self):
height, width = 32, 32
im_tensor = torch.randint(0, 255, (2, 3, height, width,), dtype=torch.uint8)
mask_tensor = torch.randint(0, 255, (2, 1, height, width,), dtype=torch.uint8) > 127.5
im_tensor = torch.randint(
0,
255,
(
2,
3,
height,
width,
),
dtype=torch.uint8,
)
mask_tensor = (
torch.randint(
0,
255,
(
2,
1,
height,
width,
),
dtype=torch.uint8,
)
> 127.5
)
im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
mask_nps = [mask.numpy()[0] for mask in mask_tensor]
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor, height, width)
t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(
im_tensor / 127.5 - 1, mask_tensor, height, width
)
nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)]
t_mask_np = torch.cat([n[0] for n in nps])
t_masked_np = torch.cat([n[1] for n in nps])
......@@ -569,43 +746,159 @@ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase)
# test height and width
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(torch.randn(3, height, width,), torch.randn(64, 64), height, width)
prepare_mask_and_masked_image(
torch.randn(
3,
height,
width,
),
torch.randn(64, 64),
height,
width,
)
# test batch dim
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(torch.randn(2, 3, height, width,), torch.randn(4, 64, 64), height, width)
prepare_mask_and_masked_image(
torch.randn(
2,
3,
height,
width,
),
torch.randn(4, 64, 64),
height,
width,
)
# test batch dim
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(torch.randn(2, 3, height, width,), torch.randn(4, 1, 64, 64), height, width)
prepare_mask_and_masked_image(
torch.randn(
2,
3,
height,
width,
),
torch.randn(4, 1, 64, 64),
height,
width,
)
def test_type_mismatch(self):
height, width = 32, 32
# test tensors-only
with self.assertRaises(TypeError):
prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.rand(3, height, width,).numpy(), height, width)
prepare_mask_and_masked_image(
torch.rand(
3,
height,
width,
),
torch.rand(
3,
height,
width,
).numpy(),
height,
width,
)
# test tensors-only
with self.assertRaises(TypeError):
prepare_mask_and_masked_image(torch.rand(3, height, width,).numpy(), torch.rand(3, height, width,), height, width)
prepare_mask_and_masked_image(
torch.rand(
3,
height,
width,
).numpy(),
torch.rand(
3,
height,
width,
),
height,
width,
)
def test_channels_first(self):
height, width = 32, 32
# test channels first for 3D tensors
with self.assertRaises(AssertionError):
prepare_mask_and_masked_image(torch.rand(height, width, 3), torch.rand(3, height, width,), height, width)
prepare_mask_and_masked_image(
torch.rand(height, width, 3),
torch.rand(
3,
height,
width,
),
height,
width,
)
def test_tensor_range(self):
height, width = 32, 32
# test im <= 1
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(torch.ones(3, height, width,) * 2, torch.rand(height, width,), height, width)
prepare_mask_and_masked_image(
torch.ones(
3,
height,
width,
)
* 2,
torch.rand(
height,
width,
),
height,
width,
)
# test im >= -1
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(torch.ones(3, height, width,) * (-2), torch.rand(height, width,), height, width)
prepare_mask_and_masked_image(
torch.ones(
3,
height,
width,
)
* (-2),
torch.rand(
height,
width,
),
height,
width,
)
# test mask <= 1
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.ones(height, width,) * 2, height, width)
prepare_mask_and_masked_image(
torch.rand(
3,
height,
width,
),
torch.ones(
height,
width,
)
* 2,
height,
width,
)
# test mask >= 0
with self.assertRaises(ValueError):
prepare_mask_and_masked_image(torch.rand(3, height, width,), torch.ones(height, width,) * -1, height, width)
prepare_mask_and_masked_image(
torch.rand(
3,
height,
width,
),
torch.ones(
height,
width,
)
* -1,
height,
width,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment