"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "1ac07d8a8dd38fb155da48e5aafbfb63b3958520"
Commit 232dad22 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

improve error message in batch transform

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/507

Reviewed By: crassirostris

Differential Revision: D44269996

fbshipit-source-id: 91b313aeb820ec39e60c29c4c1bd9e669e1f7a6b
parent 8a21c18b
...@@ -82,7 +82,10 @@ class PadBorderDivisible(aug.Augmentation): ...@@ -82,7 +82,10 @@ class PadBorderDivisible(aug.Augmentation):
def get_transform(self, image: np.ndarray) -> Transform: def get_transform(self, image: np.ndarray) -> Transform:
"""image: HxWxC""" """image: HxWxC"""
assert len(image.shape) == 3 and image.shape[2] in [1, 3] assert len(image.shape) == 3 and image.shape[2] in [
1,
3,
], f"Invalid image shape {image.shape}"
H, W = image.shape[:2] H, W = image.shape[:2]
new_h = int(math.ceil(H / self.size_divisibility) * self.size_divisibility) new_h = int(math.ceil(H / self.size_divisibility) * self.size_divisibility)
new_w = int(math.ceil(W / self.size_divisibility) * self.size_divisibility) new_w = int(math.ceil(W / self.size_divisibility) * self.size_divisibility)
...@@ -103,7 +106,10 @@ class PadToSquare(aug.Augmentation): ...@@ -103,7 +106,10 @@ class PadToSquare(aug.Augmentation):
def get_transform(self, image: np.ndarray) -> Transform: def get_transform(self, image: np.ndarray) -> Transform:
"""image: HxWxC""" """image: HxWxC"""
assert len(image.shape) == 3 and image.shape[2] in [1, 3] assert len(image.shape) == 3 and image.shape[2] in [
1,
3,
], f"Invalid image shape {image.shape}"
H, W = image.shape[:2] H, W = image.shape[:2]
new_h = new_w = max(H, W) new_h = new_w = max(H, W)
return PadTransform( return PadTransform(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment