Unverified Commit 944ddce8 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Enable passing number of channels when inferring data format (#25412)

parent cb3c821c
...@@ -144,17 +144,24 @@ def to_numpy_array(img) -> np.ndarray: ...@@ -144,17 +144,24 @@ def to_numpy_array(img) -> np.ndarray:
return to_numpy(img) return to_numpy(img)
def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension: def infer_channel_dimension_format(
image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
) -> ChannelDimension:
""" """
Infers the channel dimension format of `image`. Infers the channel dimension format of `image`.
Args: Args:
image (`np.ndarray`): image (`np.ndarray`):
The image to infer the channel dimension of. The image to infer the channel dimension of.
num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
The number of channels of the image.
Returns: Returns:
The channel dimension of the image. The channel dimension of the image.
""" """
num_channels = num_channels if num_channels is not None else (1, 3)
num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels
if image.ndim == 3: if image.ndim == 3:
first_dim, last_dim = 0, 2 first_dim, last_dim = 0, 2
elif image.ndim == 4: elif image.ndim == 4:
...@@ -162,9 +169,9 @@ def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension: ...@@ -162,9 +169,9 @@ def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension:
else: else:
raise ValueError(f"Unsupported number of image dimensions: {image.ndim}") raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
if image.shape[first_dim] in (1, 3): if image.shape[first_dim] in num_channels:
return ChannelDimension.FIRST return ChannelDimension.FIRST
elif image.shape[last_dim] in (1, 3): elif image.shape[last_dim] in num_channels:
return ChannelDimension.LAST return ChannelDimension.LAST
raise ValueError("Unable to infer channel dimension format") raise ValueError("Unable to infer channel dimension format")
......
...@@ -578,6 +578,10 @@ class UtilFunctionTester(unittest.TestCase): ...@@ -578,6 +578,10 @@ class UtilFunctionTester(unittest.TestCase):
with pytest.raises(ValueError): with pytest.raises(ValueError):
infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50))) infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)))
# But if we explicitly set one of the number of channels to 50 it works
inferred_dim = infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)), num_channels=50)
self.assertEqual(inferred_dim, ChannelDimension.LAST)
# Test we correctly identify the channel dimension # Test we correctly identify the channel dimension
image = np.random.randint(0, 256, (3, 4, 5)) image = np.random.randint(0, 256, (3, 4, 5))
inferred_dim = infer_channel_dimension_format(image) inferred_dim = infer_channel_dimension_format(image)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment