Unverified Commit bf03f4ed authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

cleanup spatial_size -> canvas_size (#7783)

parent 9ebf10af
......@@ -423,7 +423,7 @@ DEFAULT_SPATIAL_SIZES = (
)
def _parse_canvas_size(size, *, name="size"):
def _parse_size(size, *, name="size"):
if size == "random":
raise ValueError("This should never happen")
elif isinstance(size, int) and size > 0:
......@@ -478,13 +478,13 @@ class TensorLoader:
@dataclasses.dataclass
class ImageLoader(TensorLoader):
canvas_size: Tuple[int, int] = dataclasses.field(init=False)
spatial_size: Tuple[int, int] = dataclasses.field(init=False)
num_channels: int = dataclasses.field(init=False)
memory_format: torch.memory_format = torch.contiguous_format
canvas_size: Tuple[int, int] = dataclasses.field(init=False)
def __post_init__(self):
self.canvas_size = self.canvas_size = self.shape[-2:]
self.spatial_size = self.canvas_size = self.shape[-2:]
self.num_channels = self.shape[-3]
def load(self, device):
......@@ -550,7 +550,7 @@ def make_image_loader(
):
if not constant_alpha:
raise ValueError("This should never happen")
size = _parse_canvas_size(size)
size = _parse_size(size)
num_channels = get_num_channels(color_space)
def fn(shape, dtype, device, memory_format):
......@@ -590,7 +590,7 @@ make_images = from_loaders(make_image_loaders)
def make_image_loader_for_interpolation(
size=(233, 147), *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
):
size = _parse_canvas_size(size)
size = _parse_size(size)
num_channels = get_num_channels(color_space)
def fn(shape, dtype, device, memory_format):
......@@ -687,11 +687,11 @@ def make_bounding_box(
)
def make_bounding_box_loader(*, extra_dims=(), format, canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
if isinstance(format, str):
format = datapoints.BoundingBoxFormat[format]
canvas_size = _parse_canvas_size(canvas_size, name="canvas_size")
spatial_size = _parse_size(spatial_size, name="canvas_size")
def fn(shape, dtype, device):
*batch_dims, num_coordinates = shape
......@@ -699,21 +699,21 @@ def make_bounding_box_loader(*, extra_dims=(), format, canvas_size=DEFAULT_PORTR
raise pytest.UsageError()
return make_bounding_box(
format=format, canvas_size=canvas_size, batch_dims=batch_dims, dtype=dtype, device=device
format=format, canvas_size=spatial_size, batch_dims=batch_dims, dtype=dtype, device=device
)
return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=canvas_size)
return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)
def make_bounding_box_loaders(
*,
extra_dims=DEFAULT_EXTRA_DIMS,
formats=tuple(datapoints.BoundingBoxFormat),
canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
dtypes=(torch.float32, torch.float64, torch.int64),
):
for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
yield make_bounding_box_loader(**params, canvas_size=canvas_size)
yield make_bounding_box_loader(**params, spatial_size=spatial_size)
make_bounding_boxes = from_loaders(make_bounding_box_loaders)
......@@ -738,7 +738,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtyp
def make_detection_mask_loader(size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_objects=5, extra_dims=(), dtype=torch.uint8):
# This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
size = _parse_canvas_size(size)
size = _parse_size(size)
def fn(shape, dtype, device):
*batch_dims, num_objects, height, width = shape
......@@ -779,7 +779,7 @@ def make_segmentation_mask_loader(
size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_categories=10, extra_dims=(), dtype=torch.uint8
):
# This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
canvas_size = _parse_canvas_size(size)
size = _parse_size(size)
def fn(shape, dtype, device):
*batch_dims, height, width = shape
......@@ -787,7 +787,7 @@ def make_segmentation_mask_loader(
(height, width), num_categories=num_categories, batch_dims=batch_dims, dtype=dtype, device=device
)
return MaskLoader(fn, shape=(*extra_dims, *canvas_size), dtype=dtype)
return MaskLoader(fn, shape=(*extra_dims, *size), dtype=dtype)
def make_segmentation_mask_loaders(
......@@ -841,7 +841,7 @@ def make_video_loader(
extra_dims=(),
dtype=torch.uint8,
):
size = _parse_canvas_size(size)
size = _parse_size(size)
def fn(shape, dtype, device, memory_format):
*batch_dims, num_frames, _, height, width = shape
......
......@@ -884,7 +884,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)
for bboxes in make_bounding_boxes(canvas_size=canvas_size, extra_dims=((4,),)):
for bboxes in make_bounding_boxes(spatial_size=canvas_size, extra_dims=((4,),)):
bboxes = bboxes.to(device)
output_bboxes = F.perspective_bounding_boxes(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment