Unverified Commit e2bb6baa authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

[FBcode->GH] Fix accimage tests (#5545)



* Fix accimage tests

* Adding workaround for accimage

* Refactoring

* restore channels
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent ea197e4e
...@@ -23,7 +23,10 @@ def _is_pil_image(img: Any) -> bool: ...@@ -23,7 +23,10 @@ def _is_pil_image(img: Any) -> bool:
@torch.jit.unused @torch.jit.unused
def get_dimensions(img: Any) -> List[int]: def get_dimensions(img: Any) -> List[int]:
if _is_pil_image(img): if _is_pil_image(img):
if hasattr(img, "getbands"):
channels = len(img.getbands()) channels = len(img.getbands())
else:
channels = img.channels
width, height = img.size width, height = img.size
return [channels, height, width] return [channels, height, width]
raise TypeError(f"Unexpected type {type(img)}") raise TypeError(f"Unexpected type {type(img)}")
...@@ -39,7 +42,10 @@ def get_image_size(img: Any) -> List[int]: ...@@ -39,7 +42,10 @@ def get_image_size(img: Any) -> List[int]:
@torch.jit.unused @torch.jit.unused
def get_image_num_channels(img: Any) -> int: def get_image_num_channels(img: Any) -> int:
if _is_pil_image(img): if _is_pil_image(img):
if hasattr(img, "getbands"):
return len(img.getbands()) return len(img.getbands())
else:
return img.channels
raise TypeError(f"Unexpected type {type(img)}") raise TypeError(f"Unexpected type {type(img)}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment