Unverified Commit 289fce29 authored by Joao Gomes's avatar Joao Gomes Committed by GitHub
Browse files

Replace asserts with exceptions (#5587)



* replace most asserts with exceptions

* fix formating issues

* fix linting and remove more asserts

* fix regresion

* fix regresion

* fix bug

* apply ufmt

* apply ufmt

* fix tests

* fix format

* fix None check

* fix detection models tests

* non scriptable any

* add more checks for None values

* fix retinanet test

* fix retinanet test

* Update references/classification/transforms.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update references/classification/transforms.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update references/optical_flow/transforms.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update references/optical_flow/transforms.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update references/optical_flow/transforms.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* make value checks more pythonic:

* Update references/optical_flow/transforms.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* make value checks more pythonic

* make more checks pythonic

* fix bug

* appy ufmt

* fix tracing issues

* fib typos

* fix lint

* remove unecessary f-strings

* fix bug

* Update torchvision/datasets/mnist.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update torchvision/datasets/mnist.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update torchvision/ops/boxes.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update torchvision/ops/poolers.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update torchvision/utils.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* address PR comments

* Update torchvision/io/_video_opt.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update torchvision/models/detection/generalized_rcnn.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update torchvision/models/feature_extraction.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update torchvision/models/optical_flow/raft.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* address PR comments

* addressing further pr comments

* fix bug

* remove unecessary else

* apply ufmt

* last pr comment

* replace RuntimeErrors
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent 9bbb777d
......@@ -300,8 +300,10 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
_log_api_usage_once(generalized_box_iou)
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
if (boxes1[:, 2:] < boxes1[:, :2]).any():
raise ValueError("Some of the input boxes1 are invalid.")
if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
raise ValueError("Some of the input boxes2 are invalid.")
inter, union = _box_inter_union(boxes1, boxes2)
iou = inter / union
......
......@@ -95,7 +95,8 @@ class FeaturePyramidNetwork(nn.Module):
nn.init.constant_(m.bias, 0)
if extra_blocks is not None:
assert isinstance(extra_blocks, ExtraFPNBlock)
if not isinstance(extra_blocks, ExtraFPNBlock):
raise TypeError(f"extra_blocks should be of type ExtraFPNBlock not {type(extra_blocks)}")
self.extra_blocks = extra_blocks
def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
......
......@@ -104,7 +104,6 @@ def _infer_scale(feature: Tensor, original_size: List[int]) -> float:
approx_scale = float(s1) / float(s2)
scale = 2 ** float(torch.tensor(approx_scale).log2().round())
possible_scales.append(scale)
assert possible_scales[0] == possible_scales[1]
return possible_scales[0]
......@@ -112,7 +111,8 @@ def _infer_scale(feature: Tensor, original_size: List[int]) -> float:
def _setup_scales(
features: List[Tensor], image_shapes: List[Tuple[int, int]], canonical_scale: int, canonical_level: int
) -> Tuple[List[float], LevelMapper]:
assert len(image_shapes) != 0
if not image_shapes:
raise ValueError("images list should not be empty")
max_x = 0
max_y = 0
for shape in image_shapes:
......@@ -166,8 +166,8 @@ def _multiscale_roi_align(
Returns:
result (Tensor)
"""
assert scales is not None
assert mapper is not None
if scales is None or mapper is None:
raise ValueError("scales and mapper should not be None")
num_levels = len(x_filtered)
rois = _convert_to_roi_format(boxes)
......
......@@ -98,7 +98,10 @@ def ssdlite320_mobilenet_v3_large(
anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
out_channels = det_utils.retrieve_out_channels(backbone, size)
num_anchors = anchor_generator.num_anchors_per_location()
assert len(out_channels) == len(anchor_generator.aspect_ratios)
if len(out_channels) != len(anchor_generator.aspect_ratios):
raise ValueError(
f"The length of the output channels from the backbone {len(out_channels)} do not match the length of the anchor generator aspect ratios {len(anchor_generator.aspect_ratios)}"
)
defaults = {
"score_thresh": 0.001,
......
......@@ -24,12 +24,14 @@ def crop(clip, i, j, h, w):
Args:
clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
"""
assert len(clip.size()) == 4, "clip should be a 4D tensor"
if len(clip.size()) != 4:
raise ValueError("clip should be a 4D tensor")
return clip[..., i : i + h, j : j + w]
def resize(clip, target_size, interpolation_mode):
assert len(target_size) == 2, "target size should be tuple (height, width)"
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
......@@ -46,17 +48,20 @@ def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
Returns:
clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
"""
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
clip = crop(clip, i, j, h, w)
clip = resize(clip, size, interpolation_mode)
return clip
def center_crop(clip, crop_size):
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
th, tw = crop_size
assert h >= th and w >= tw, "height and width must be no smaller than crop_size"
if h < th or w < tw:
raise ValueError("height and width must be no smaller than crop_size")
i = int(round((h - th) / 2.0))
j = int(round((w - tw) / 2.0))
......@@ -87,7 +92,8 @@ def normalize(clip, mean, std, inplace=False):
Returns:
normalized clip (torch.tensor): Size is (C, T, H, W)
"""
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
if not inplace:
clip = clip.clone()
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
......@@ -103,5 +109,6 @@ def hflip(clip):
Returns:
flipped clip (torch.tensor): Size is (C, T, H, W)
"""
assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
return clip.flip(-1)
......@@ -59,7 +59,8 @@ class RandomResizedCropVideo(RandomResizedCrop):
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
assert len(size) == 2, "size should be tuple (height, width)"
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
......
......@@ -82,10 +82,8 @@ def make_grid(
if normalize is True:
tensor = tensor.clone() # avoid modifying tensor in-place
if value_range is not None:
assert isinstance(
value_range, tuple
), "value_range has to be a tuple (min, max) if specified. min and max are numbers"
if value_range is not None and not isinstance(value_range, tuple):
raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers")
def norm_ip(img, low, high):
img.clamp_(min=low, max=high)
......@@ -103,7 +101,8 @@ def make_grid(
else:
norm_range(tensor, value_range)
assert isinstance(tensor, torch.Tensor)
if not isinstance(tensor, torch.Tensor):
raise TypeError("tensor should be of type torch.Tensor")
if tensor.size(0) == 1:
return tensor.squeeze(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment