"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "edd7880418dd36da26ee434f556756f7168100c5"
Unverified commit e0fd033c, authored by ahmadsharif1 and committed by GitHub
Browse files

Expand the channels to 3 if the user requested as such (#8229)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent 7f55a1b3
......@@ -4935,15 +4935,24 @@ class TestRgbToGrayscale:
check_transform(transform, make_input())
@pytest.mark.parametrize("num_output_channels", [1, 3])
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
@pytest.mark.parametrize("fn", [F.rgb_to_grayscale, transform_cls_to_functional(transforms.Grayscale)])
def test_image_correctness(self, num_output_channels, color_space, fn):
    """Compare the tensor grayscale conversion against the PIL code path.

    Runs for every combination of requested output channel count, input
    color space, and entry point (functional API and the ``Grayscale``
    transform wrapped as a functional).
    """
    # The signature must accept ``color_space``: it is supplied by the
    # parametrization above and forwarded to the input factory.
    image = make_image(dtype=torch.uint8, device="cpu", color_space=color_space)

    actual = fn(image, num_output_channels=num_output_channels)

    # Reference result: round-trip through PIL and convert back to an image.
    expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_output_channels))

    # Exact relative match, but allow an absolute difference of one between
    # the tensor and PIL results.
    assert_equal(actual, expected, rtol=0, atol=1)
def test_expanded_channels_are_not_views_into_the_same_underlying_tensor(self):
    """Check that writing to one output channel does not alter another.

    A ``GRAY`` image is converted with ``num_output_channels=3``; the first
    element of channel 0 is then modified and must no longer equal the
    corresponding element of channel 1.
    """
    gray = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")
    expanded = F.rgb_to_grayscale(gray, num_output_channels=3)

    # Before any write, the first element of channels 0 and 1 must match.
    assert_equal(expanded[0][0][0], expanded[1][0][0])

    # Modify channel 0 only; if the channels shared storage, channel 1 would
    # change as well and the inequality below would fail.
    expanded[0][0][0] = expanded[0][0][0] + 1
    assert expanded[0][0][0] != expanded[1][0][0]
@pytest.mark.parametrize("num_input_channels", [1, 3])
def test_random_transform_correctness(self, num_input_channels):
image = make_image(
......
......@@ -33,9 +33,13 @@ to_grayscale = rgb_to_grayscale
def _rgb_to_grayscale_image(
image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True
) -> torch.Tensor:
if image.shape[-3] == 1:
# TODO: Maybe move the validation that num_output_channels is 1 or 3 to this function instead of callers.
if image.shape[-3] == 1 and num_output_channels == 1:
return image.clone()
if image.shape[-3] == 1 and num_output_channels == 3:
s = [1] * len(image.shape)
s[-3] = 3
return image.repeat(s)
r, g, b = image.unbind(dim=-3)
l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
l_img = l_img.unsqueeze(dim=-3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment