Unverified Commit 289fce29 authored by Joao Gomes's avatar Joao Gomes Committed by GitHub
Browse files

Replace asserts with exceptions (#5587)



* replace most asserts with exceptions

* fix formating issues

* fix linting and remove more asserts

* fix regresion

* fix regresion

* fix bug

* apply ufmt

* apply ufmt

* fix tests

* fix format

* fix None check

* fix detection models tests

* non scriptable any

* add more checks for None values

* fix retinanet test

* fix retinanet test

* Update references/classification/transforms.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update references/classification/transforms.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update references/optical_flow/transforms.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update references/optical_flow/transforms.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update references/optical_flow/transforms.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* make value checks more pythonic:

* Update references/optical_flow/transforms.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* make value checks more pythonic

* make more checks pythonic

* fix bug

* appy ufmt

* fix tracing issues

* fib typos

* fix lint

* remove unecessary f-strings

* fix bug

* Update torchvision/datasets/mnist.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update torchvision/datasets/mnist.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update torchvision/ops/boxes.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update torchvision/ops/poolers.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update torchvision/utils.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* address PR comments

* Update torchvision/io/_video_opt.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update torchvision/models/detection/generalized_rcnn.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update torchvision/models/feature_extraction.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* Update torchvision/models/optical_flow/raft.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>

* address PR comments

* addressing further pr comments

* fix bug

* remove unecessary else

* apply ufmt

* last pr comment

* replace RuntimeErrors
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent 9bbb777d
...@@ -21,8 +21,14 @@ class RandomMixup(torch.nn.Module): ...@@ -21,8 +21,14 @@ class RandomMixup(torch.nn.Module):
def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
super().__init__() super().__init__()
assert num_classes > 0, "Please provide a valid positive value for the num_classes."
assert alpha > 0, "Alpha param can't be zero." if num_classes < 1:
raise ValueError(
f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
)
if alpha <= 0:
raise ValueError("Alpha param can't be zero.")
self.num_classes = num_classes self.num_classes = num_classes
self.p = p self.p = p
...@@ -99,8 +105,10 @@ class RandomCutmix(torch.nn.Module): ...@@ -99,8 +105,10 @@ class RandomCutmix(torch.nn.Module):
def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
super().__init__() super().__init__()
assert num_classes > 0, "Please provide a valid positive value for the num_classes." if num_classes < 1:
assert alpha > 0, "Alpha param can't be zero." raise ValueError("Please provide a valid positive value for the num_classes.")
if alpha <= 0:
raise ValueError("Alpha param can't be zero.")
self.num_classes = num_classes self.num_classes = num_classes
self.p = p self.p = p
......
...@@ -12,7 +12,8 @@ from pycocotools.cocoeval import COCOeval ...@@ -12,7 +12,8 @@ from pycocotools.cocoeval import COCOeval
class CocoEvaluator: class CocoEvaluator:
def __init__(self, coco_gt, iou_types): def __init__(self, coco_gt, iou_types):
assert isinstance(iou_types, (list, tuple)) if not isinstance(iou_types, (list, tuple)):
raise TypeError(f"This constructor expects iou_types of type list or tuple, instead got {type(iou_types)}")
coco_gt = copy.deepcopy(coco_gt) coco_gt = copy.deepcopy(coco_gt)
self.coco_gt = coco_gt self.coco_gt = coco_gt
......
...@@ -126,7 +126,10 @@ def _coco_remove_images_without_annotations(dataset, cat_list=None): ...@@ -126,7 +126,10 @@ def _coco_remove_images_without_annotations(dataset, cat_list=None):
return True return True
return False return False
assert isinstance(dataset, torchvision.datasets.CocoDetection) if not isinstance(dataset, torchvision.datasets.CocoDetection):
raise TypeError(
f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
)
ids = [] ids = []
for ds_idx, img_id in enumerate(dataset.ids): for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
......
...@@ -7,16 +7,21 @@ class ValidateModelInput(torch.nn.Module): ...@@ -7,16 +7,21 @@ class ValidateModelInput(torch.nn.Module):
# Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects # Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects
def forward(self, img1, img2, flow, valid_flow_mask): def forward(self, img1, img2, flow, valid_flow_mask):
assert all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None) if not all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None):
assert all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None) raise TypeError("This method expects all input arguments to be of type torch.Tensor.")
if not all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None):
raise TypeError("This method expects the tensors img1, img2 and flow of be of dtype torch.float32.")
assert img1.shape == img2.shape if img1.shape != img2.shape:
raise ValueError("img1 and img2 should have the same shape.")
h, w = img1.shape[-2:] h, w = img1.shape[-2:]
if flow is not None: if flow is not None and flow.shape != (2, h, w):
assert flow.shape == (2, h, w) raise ValueError(f"flow.shape should be (2, {h}, {w}) instead of {flow.shape}")
if valid_flow_mask is not None: if valid_flow_mask is not None:
assert valid_flow_mask.shape == (h, w) if valid_flow_mask.shape != (h, w):
assert valid_flow_mask.dtype == torch.bool raise ValueError(f"valid_flow_mask.shape should be ({h}, {w}) instead of {valid_flow_mask.shape}")
if valid_flow_mask.dtype != torch.bool:
raise TypeError("valid_flow_mask should be of dtype torch.bool instead of {valid_flow_mask.dtype}")
return img1, img2, flow, valid_flow_mask return img1, img2, flow, valid_flow_mask
...@@ -109,7 +114,8 @@ class RandomErasing(T.RandomErasing): ...@@ -109,7 +114,8 @@ class RandomErasing(T.RandomErasing):
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False, max_erase=1): def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False, max_erase=1):
super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace) super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace)
self.max_erase = max_erase self.max_erase = max_erase
assert self.max_erase > 0 if self.max_erase <= 0:
raise ValueError("max_raise should be greater than 0")
def forward(self, img1, img2, flow, valid_flow_mask): def forward(self, img1, img2, flow, valid_flow_mask):
if torch.rand(1) > self.p: if torch.rand(1) > self.p:
......
...@@ -71,7 +71,10 @@ class MetricLogger: ...@@ -71,7 +71,10 @@ class MetricLogger:
for k, v in kwargs.items(): for k, v in kwargs.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
v = v.item() v = v.item()
assert isinstance(v, (float, int)) if not isinstance(v, (float, int)):
raise TypeError(
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
)
self.meters[k].update(v) self.meters[k].update(v)
def __getattr__(self, attr): def __getattr__(self, attr):
......
...@@ -68,7 +68,11 @@ def _coco_remove_images_without_annotations(dataset, cat_list=None): ...@@ -68,7 +68,11 @@ def _coco_remove_images_without_annotations(dataset, cat_list=None):
# if more than 1k pixels occupied in the image # if more than 1k pixels occupied in the image
return sum(obj["area"] for obj in anno) > 1000 return sum(obj["area"] for obj in anno) > 1000
assert isinstance(dataset, torchvision.datasets.CocoDetection) if not isinstance(dataset, torchvision.datasets.CocoDetection):
raise TypeError(
f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
)
ids = [] ids = []
for ds_idx, img_id in enumerate(dataset.ids): for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
......
...@@ -118,7 +118,10 @@ class MetricLogger: ...@@ -118,7 +118,10 @@ class MetricLogger:
for k, v in kwargs.items(): for k, v in kwargs.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
v = v.item() v = v.item()
assert isinstance(v, (float, int)) if not isinstance(v, (float, int)):
raise TypeError(
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
)
self.meters[k].update(v) self.meters[k].update(v)
def __getattr__(self, attr): def __getattr__(self, attr):
......
...@@ -47,7 +47,8 @@ class PKSampler(Sampler): ...@@ -47,7 +47,8 @@ class PKSampler(Sampler):
self.groups = create_groups(groups, self.k) self.groups = create_groups(groups, self.k)
# Ensures there are enough classes to sample from # Ensures there are enough classes to sample from
assert len(self.groups) >= p if len(self.groups) < p:
raise ValueError("There are not enought classes to sample from")
def __iter__(self): def __iter__(self):
# Shuffle samples within groups # Shuffle samples within groups
......
...@@ -76,7 +76,10 @@ class MetricLogger: ...@@ -76,7 +76,10 @@ class MetricLogger:
for k, v in kwargs.items(): for k, v in kwargs.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
v = v.item() v = v.item()
assert isinstance(v, (float, int)) if not isinstance(v, (float, int)):
raise TypeError(
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
)
self.meters[k].update(v) self.meters[k].update(v)
def __getattr__(self, attr): def __getattr__(self, attr):
......
...@@ -144,16 +144,16 @@ class TestFxFeatureExtraction: ...@@ -144,16 +144,16 @@ class TestFxFeatureExtraction:
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
) )
# Check must specify return nodes # Check must specify return nodes
with pytest.raises(AssertionError): with pytest.raises(ValueError):
self._create_feature_extractor(model) self._create_feature_extractor(model)
# Check return_nodes and train_return_nodes / eval_return nodes # Check return_nodes and train_return_nodes / eval_return nodes
# mutual exclusivity # mutual exclusivity
with pytest.raises(AssertionError): with pytest.raises(ValueError):
self._create_feature_extractor( self._create_feature_extractor(
model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes
) )
# Check train_return_nodes / eval_return nodes must both be specified # Check train_return_nodes / eval_return nodes must both be specified
with pytest.raises(AssertionError): with pytest.raises(ValueError):
self._create_feature_extractor(model, train_return_nodes=train_return_nodes) self._create_feature_extractor(model, train_return_nodes=train_return_nodes)
# Check invalid node name raises ValueError # Check invalid node name raises ValueError
with pytest.raises(ValueError): with pytest.raises(ValueError):
......
...@@ -767,7 +767,7 @@ def test_detection_model_validation(model_fn): ...@@ -767,7 +767,7 @@ def test_detection_model_validation(model_fn):
# validate type # validate type
targets = [{"boxes": 0.0}] targets = [{"boxes": 0.0}]
with pytest.raises(ValueError): with pytest.raises(TypeError):
model(x, targets=targets) model(x, targets=targets)
# validate boxes shape # validate boxes shape
......
...@@ -138,13 +138,13 @@ class RoIOpTester(ABC): ...@@ -138,13 +138,13 @@ class RoIOpTester(ABC):
def _helper_boxes_shape(self, func): def _helper_boxes_shape(self, func):
# test boxes as Tensor[N, 5] # test boxes as Tensor[N, 5]
with pytest.raises(AssertionError): with pytest.raises(ValueError):
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8) a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype) boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
func(a, boxes, output_size=(2, 2)) func(a, boxes, output_size=(2, 2))
# test boxes as List[Tensor[N, 4]] # test boxes as List[Tensor[N, 4]]
with pytest.raises(AssertionError): with pytest.raises(ValueError):
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8) a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype) boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
ops.roi_pool(a, [boxes], output_size=(2, 2)) ops.roi_pool(a, [boxes], output_size=(2, 2))
......
...@@ -118,7 +118,8 @@ class Kinetics(VisionDataset): ...@@ -118,7 +118,8 @@ class Kinetics(VisionDataset):
print("Using legacy structure") print("Using legacy structure")
self.split_folder = root self.split_folder = root
self.split = "unknown" self.split = "unknown"
assert not download, "Cannot download the videos using legacy_structure." if download:
raise ValueError("Cannot download the videos using legacy_structure.")
else: else:
self.split_folder = path.join(root, split) self.split_folder = path.join(root, split)
self.split = verify_str_arg(split, arg="split", valid_values=["train", "val"]) self.split = verify_str_arg(split, arg="split", valid_values=["train", "val"])
......
...@@ -442,11 +442,14 @@ class QMNIST(MNIST): ...@@ -442,11 +442,14 @@ class QMNIST(MNIST):
def _load_data(self): def _load_data(self):
data = read_sn3_pascalvincent_tensor(self.images_file) data = read_sn3_pascalvincent_tensor(self.images_file)
assert data.dtype == torch.uint8 if data.dtype != torch.uint8:
assert data.ndimension() == 3 raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}")
if data.ndimension() != 3:
raise ValueError("data should have 3 dimensions instead of {data.ndimension()}")
targets = read_sn3_pascalvincent_tensor(self.labels_file).long() targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
assert targets.ndimension() == 2 if targets.ndimension() != 2:
raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}")
if self.what == "test10k": if self.what == "test10k":
data = data[0:10000, :, :].clone() data = data[0:10000, :, :].clone()
...@@ -530,13 +533,17 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso ...@@ -530,13 +533,17 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
def read_label_file(path: str) -> torch.Tensor: def read_label_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False) x = read_sn3_pascalvincent_tensor(path, strict=False)
assert x.dtype == torch.uint8 if x.dtype != torch.uint8:
assert x.ndimension() == 1 raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
if x.ndimension() != 1:
raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}")
return x.long() return x.long()
def read_image_file(path: str) -> torch.Tensor: def read_image_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False) x = read_sn3_pascalvincent_tensor(path, strict=False)
assert x.dtype == torch.uint8 if x.dtype != torch.uint8:
assert x.ndimension() == 3 raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
if x.ndimension() != 3:
raise ValueError(f"x should have 3 dimension instead of {x.ndimension()}")
return x return x
...@@ -52,12 +52,10 @@ class DistributedSampler(Sampler): ...@@ -52,12 +52,10 @@ class DistributedSampler(Sampler):
if not dist.is_available(): if not dist.is_available():
raise RuntimeError("Requires distributed package to be available") raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank() rank = dist.get_rank()
assert ( if len(dataset) % group_size != 0:
len(dataset) % group_size == 0 raise ValueError(
), "dataset length must be a multiplier of group size dataset length: %d, group size: %d" % ( f"dataset length must be a multiplier of group size dataset length: {len(dataset)}, group size: {group_size}"
len(dataset), )
group_size,
)
self.dataset = dataset self.dataset = dataset
self.group_size = group_size self.group_size = group_size
self.num_replicas = num_replicas self.num_replicas = num_replicas
......
...@@ -92,7 +92,6 @@ class SBDataset(VisionDataset): ...@@ -92,7 +92,6 @@ class SBDataset(VisionDataset):
self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names] self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names]
assert len(self.images) == len(self.masks)
self._get_target = self._get_segmentation_target if self.mode == "segmentation" else self._get_boundaries_target self._get_target = self._get_segmentation_target if self.mode == "segmentation" else self._get_boundaries_target
......
...@@ -38,7 +38,8 @@ def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> tor ...@@ -38,7 +38,8 @@ def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> tor
`step` between windows. The distance between each element `step` between windows. The distance between each element
in a window is given by `dilation`. in a window is given by `dilation`.
""" """
assert tensor.dim() == 1 if tensor.dim() != 1:
raise ValueError(f"tensor should have 1 dimension instead of {tensor.dim()}")
o_stride = tensor.stride(0) o_stride = tensor.stride(0)
numel = tensor.numel() numel = tensor.numel()
new_stride = (step * o_stride, dilation * o_stride) new_stride = (step * o_stride, dilation * o_stride)
......
...@@ -67,13 +67,9 @@ class VideoMetaData: ...@@ -67,13 +67,9 @@ class VideoMetaData:
def _validate_pts(pts_range: Tuple[int, int]) -> None: def _validate_pts(pts_range: Tuple[int, int]) -> None:
if pts_range[1] > 0: if pts_range[0] > pts_range[1] > 0:
assert ( raise ValueError(
pts_range[0] <= pts_range[1] f"Start pts should not be smaller than end pts, got start pts: {pts_range[0]} and end pts: {pts_range[1]}"
), """Start pts should not be smaller than end pts, got
start pts: {:d} and end pts: {:d}""".format(
pts_range[0],
pts_range[1],
) )
......
...@@ -159,8 +159,10 @@ class BoxCoder: ...@@ -159,8 +159,10 @@ class BoxCoder:
return targets return targets
def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor: def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
assert isinstance(boxes, (list, tuple)) if not isinstance(boxes, (list, tuple)):
assert isinstance(rel_codes, torch.Tensor) raise TypeError(f"This function expects boxes of type list or tuple, instead got {type(boxes)}")
if not isinstance(rel_codes, torch.Tensor):
raise TypeError(f"This function expects rel_codes of type torch.Tensor, instead got {type(rel_codes)}")
boxes_per_image = [b.size(0) for b in boxes] boxes_per_image = [b.size(0) for b in boxes]
concat_boxes = torch.cat(boxes, dim=0) concat_boxes = torch.cat(boxes, dim=0)
box_sum = 0 box_sum = 0
...@@ -333,7 +335,8 @@ class Matcher: ...@@ -333,7 +335,8 @@ class Matcher:
""" """
self.BELOW_LOW_THRESHOLD = -1 self.BELOW_LOW_THRESHOLD = -1
self.BETWEEN_THRESHOLDS = -2 self.BETWEEN_THRESHOLDS = -2
assert low_threshold <= high_threshold if low_threshold > high_threshold:
raise ValueError("low_threshold should be <= high_threshold")
self.high_threshold = high_threshold self.high_threshold = high_threshold
self.low_threshold = low_threshold self.low_threshold = low_threshold
self.allow_low_quality_matches = allow_low_quality_matches self.allow_low_quality_matches = allow_low_quality_matches
...@@ -371,7 +374,8 @@ class Matcher: ...@@ -371,7 +374,8 @@ class Matcher:
matches[between_thresholds] = self.BETWEEN_THRESHOLDS matches[between_thresholds] = self.BETWEEN_THRESHOLDS
if self.allow_low_quality_matches: if self.allow_low_quality_matches:
assert all_matches is not None if all_matches is None:
raise ValueError("all_matches should not be None")
self.set_low_quality_matches_(matches, all_matches, match_quality_matrix) self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
return matches return matches
......
...@@ -45,8 +45,6 @@ class AnchorGenerator(nn.Module): ...@@ -45,8 +45,6 @@ class AnchorGenerator(nn.Module):
if not isinstance(aspect_ratios[0], (list, tuple)): if not isinstance(aspect_ratios[0], (list, tuple)):
aspect_ratios = (aspect_ratios,) * len(sizes) aspect_ratios = (aspect_ratios,) * len(sizes)
assert len(sizes) == len(aspect_ratios)
self.sizes = sizes self.sizes = sizes
self.aspect_ratios = aspect_ratios self.aspect_ratios = aspect_ratios
self.cell_anchors = [ self.cell_anchors = [
...@@ -86,7 +84,9 @@ class AnchorGenerator(nn.Module): ...@@ -86,7 +84,9 @@ class AnchorGenerator(nn.Module):
def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]: def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
anchors = [] anchors = []
cell_anchors = self.cell_anchors cell_anchors = self.cell_anchors
assert cell_anchors is not None
if cell_anchors is None:
ValueError("cell_anchors should not be None")
if not (len(grid_sizes) == len(strides) == len(cell_anchors)): if not (len(grid_sizes) == len(strides) == len(cell_anchors)):
raise ValueError( raise ValueError(
...@@ -164,8 +164,8 @@ class DefaultBoxGenerator(nn.Module): ...@@ -164,8 +164,8 @@ class DefaultBoxGenerator(nn.Module):
clip: bool = True, clip: bool = True,
): ):
super().__init__() super().__init__()
if steps is not None: if steps is not None and len(aspect_ratios) != len(steps):
assert len(aspect_ratios) == len(steps) raise ValueError("aspect_ratios and steps should have the same length")
self.aspect_ratios = aspect_ratios self.aspect_ratios = aspect_ratios
self.steps = steps self.steps = steps
self.clip = clip self.clip = clip
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment