"...text-generation-inference.git" did not exist on "f59fb8b630844c2ad2cd80e689202de89d45c37e"
Unverified Commit bf03f4ed authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

cleanup spatial_size -> canvas_size (#7783)

parent 9ebf10af
...@@ -423,7 +423,7 @@ DEFAULT_SPATIAL_SIZES = ( ...@@ -423,7 +423,7 @@ DEFAULT_SPATIAL_SIZES = (
) )
def _parse_canvas_size(size, *, name="size"): def _parse_size(size, *, name="size"):
if size == "random": if size == "random":
raise ValueError("This should never happen") raise ValueError("This should never happen")
elif isinstance(size, int) and size > 0: elif isinstance(size, int) and size > 0:
...@@ -478,13 +478,13 @@ class TensorLoader: ...@@ -478,13 +478,13 @@ class TensorLoader:
@dataclasses.dataclass @dataclasses.dataclass
class ImageLoader(TensorLoader): class ImageLoader(TensorLoader):
canvas_size: Tuple[int, int] = dataclasses.field(init=False) spatial_size: Tuple[int, int] = dataclasses.field(init=False)
num_channels: int = dataclasses.field(init=False) num_channels: int = dataclasses.field(init=False)
memory_format: torch.memory_format = torch.contiguous_format memory_format: torch.memory_format = torch.contiguous_format
canvas_size: Tuple[int, int] = dataclasses.field(init=False) canvas_size: Tuple[int, int] = dataclasses.field(init=False)
def __post_init__(self): def __post_init__(self):
self.canvas_size = self.canvas_size = self.shape[-2:] self.spatial_size = self.canvas_size = self.shape[-2:]
self.num_channels = self.shape[-3] self.num_channels = self.shape[-3]
def load(self, device): def load(self, device):
...@@ -550,7 +550,7 @@ def make_image_loader( ...@@ -550,7 +550,7 @@ def make_image_loader(
): ):
if not constant_alpha: if not constant_alpha:
raise ValueError("This should never happen") raise ValueError("This should never happen")
size = _parse_canvas_size(size) size = _parse_size(size)
num_channels = get_num_channels(color_space) num_channels = get_num_channels(color_space)
def fn(shape, dtype, device, memory_format): def fn(shape, dtype, device, memory_format):
...@@ -590,7 +590,7 @@ make_images = from_loaders(make_image_loaders) ...@@ -590,7 +590,7 @@ make_images = from_loaders(make_image_loaders)
def make_image_loader_for_interpolation( def make_image_loader_for_interpolation(
size=(233, 147), *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format size=(233, 147), *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
): ):
size = _parse_canvas_size(size) size = _parse_size(size)
num_channels = get_num_channels(color_space) num_channels = get_num_channels(color_space)
def fn(shape, dtype, device, memory_format): def fn(shape, dtype, device, memory_format):
...@@ -687,11 +687,11 @@ def make_bounding_box( ...@@ -687,11 +687,11 @@ def make_bounding_box(
) )
def make_bounding_box_loader(*, extra_dims=(), format, canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32): def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
if isinstance(format, str): if isinstance(format, str):
format = datapoints.BoundingBoxFormat[format] format = datapoints.BoundingBoxFormat[format]
canvas_size = _parse_canvas_size(canvas_size, name="canvas_size") spatial_size = _parse_size(spatial_size, name="canvas_size")
def fn(shape, dtype, device): def fn(shape, dtype, device):
*batch_dims, num_coordinates = shape *batch_dims, num_coordinates = shape
...@@ -699,21 +699,21 @@ def make_bounding_box_loader(*, extra_dims=(), format, canvas_size=DEFAULT_PORTR ...@@ -699,21 +699,21 @@ def make_bounding_box_loader(*, extra_dims=(), format, canvas_size=DEFAULT_PORTR
raise pytest.UsageError() raise pytest.UsageError()
return make_bounding_box( return make_bounding_box(
format=format, canvas_size=canvas_size, batch_dims=batch_dims, dtype=dtype, device=device format=format, canvas_size=spatial_size, batch_dims=batch_dims, dtype=dtype, device=device
) )
return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=canvas_size) return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)
def make_bounding_box_loaders( def make_bounding_box_loaders(
*, *,
extra_dims=DEFAULT_EXTRA_DIMS, extra_dims=DEFAULT_EXTRA_DIMS,
formats=tuple(datapoints.BoundingBoxFormat), formats=tuple(datapoints.BoundingBoxFormat),
canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
dtypes=(torch.float32, torch.float64, torch.int64), dtypes=(torch.float32, torch.float64, torch.int64),
): ):
for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes): for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
yield make_bounding_box_loader(**params, canvas_size=canvas_size) yield make_bounding_box_loader(**params, spatial_size=spatial_size)
make_bounding_boxes = from_loaders(make_bounding_box_loaders) make_bounding_boxes = from_loaders(make_bounding_box_loaders)
...@@ -738,7 +738,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtyp ...@@ -738,7 +738,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtyp
def make_detection_mask_loader(size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_objects=5, extra_dims=(), dtype=torch.uint8): def make_detection_mask_loader(size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_objects=5, extra_dims=(), dtype=torch.uint8):
# This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects # This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
size = _parse_canvas_size(size) size = _parse_size(size)
def fn(shape, dtype, device): def fn(shape, dtype, device):
*batch_dims, num_objects, height, width = shape *batch_dims, num_objects, height, width = shape
...@@ -779,7 +779,7 @@ def make_segmentation_mask_loader( ...@@ -779,7 +779,7 @@ def make_segmentation_mask_loader(
size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_categories=10, extra_dims=(), dtype=torch.uint8 size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_categories=10, extra_dims=(), dtype=torch.uint8
): ):
# This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values # This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
canvas_size = _parse_canvas_size(size) size = _parse_size(size)
def fn(shape, dtype, device): def fn(shape, dtype, device):
*batch_dims, height, width = shape *batch_dims, height, width = shape
...@@ -787,7 +787,7 @@ def make_segmentation_mask_loader( ...@@ -787,7 +787,7 @@ def make_segmentation_mask_loader(
(height, width), num_categories=num_categories, batch_dims=batch_dims, dtype=dtype, device=device (height, width), num_categories=num_categories, batch_dims=batch_dims, dtype=dtype, device=device
) )
return MaskLoader(fn, shape=(*extra_dims, *canvas_size), dtype=dtype) return MaskLoader(fn, shape=(*extra_dims, *size), dtype=dtype)
def make_segmentation_mask_loaders( def make_segmentation_mask_loaders(
...@@ -841,7 +841,7 @@ def make_video_loader( ...@@ -841,7 +841,7 @@ def make_video_loader(
extra_dims=(), extra_dims=(),
dtype=torch.uint8, dtype=torch.uint8,
): ):
size = _parse_canvas_size(size) size = _parse_size(size)
def fn(shape, dtype, device, memory_format): def fn(shape, dtype, device, memory_format):
*batch_dims, num_frames, _, height, width = shape *batch_dims, num_frames, _, height, width = shape
......
...@@ -884,7 +884,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints): ...@@ -884,7 +884,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
pcoeffs = _get_perspective_coeffs(startpoints, endpoints) pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints) inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)
for bboxes in make_bounding_boxes(canvas_size=canvas_size, extra_dims=((4,),)): for bboxes in make_bounding_boxes(spatial_size=canvas_size, extra_dims=((4,),)):
bboxes = bboxes.to(device) bboxes = bboxes.to(device)
output_bboxes = F.perspective_bounding_boxes( output_bboxes = F.perspective_bounding_boxes(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment