"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "38bed912e36e1725ede1f0e8c61a514f378697c3"
Unverified Commit 55ba3190 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Add param_name to size_dict logs & tidy (#20205)

parent f1e8c48c
...@@ -440,11 +440,50 @@ class BaseImageProcessor(ImageProcessingMixin): ...@@ -440,11 +440,50 @@ class BaseImageProcessor(ImageProcessingMixin):
raise NotImplementedError("Each image processor must implement its own preprocess method") raise NotImplementedError("Each image processor must implement its own preprocess method")
# The only key combinations a size dictionary may have. A dictionary is valid
# when its key set equals exactly one of these sets (no missing or extra keys).
VALID_SIZE_DICT_KEYS = ({"height", "width"}, {"shortest_edge"}, {"shortest_edge", "longest_edge"})


def is_valid_size_dict(size_dict):
    """
    Check whether `size_dict` is a dictionary with an accepted set of keys.

    Args:
        size_dict:
            The object to check. Anything that is not a `dict` is rejected.

    Returns:
        `bool`: `True` if `size_dict` is a `dict` whose keys are exactly one of the sets listed in
        `VALID_SIZE_DICT_KEYS`, `False` otherwise. Only the keys are inspected, never the values.
    """
    if not isinstance(size_dict, dict):
        return False

    size_dict_keys = set(size_dict.keys())
    # Exact set equality: a superset or subset of an allowed key set is not valid.
    return any(size_dict_keys == allowed_keys for allowed_keys in VALID_SIZE_DICT_KEYS)
def convert_to_size_dict(
    size, max_size: Optional[int] = None, default_to_square: bool = True, height_width_order: bool = True
):
    """
    Convert a legacy `size` value (an int or a 2-element tuple/list) into a size dictionary.

    Args:
        size (`int`, `tuple` or `list`):
            The value to convert. An int is either the side of a square or the shortest edge, depending on
            `default_to_square`. For a tuple/list only the first two elements are read.
        max_size (`int`, *optional*):
            Stored as `"longest_edge"` when `size` is an int and `default_to_square` is `False`. It is not used
            when `size` is a tuple/list.
        default_to_square (`bool`, *optional*, defaults to `True`):
            If `size` is an int, whether it describes a square (`height == width == size`) or the shortest edge.
        height_width_order (`bool`, *optional*, defaults to `True`):
            If `size` is a tuple/list, whether it is ordered `(height, width)` or `(width, height)`.

    Returns:
        `dict`: One of `{"height", "width"}`, `{"shortest_edge"}` or `{"shortest_edge", "longest_edge"}`.

    Raises:
        ValueError: If `size` is an int, `default_to_square` is `True` and `max_size` is given; or if `size`
            cannot be converted (unsupported type, or a tuple/list with fewer than two elements).
    """
    if isinstance(size, int):
        # By default, if size is an int we assume it represents a tuple of (size, size).
        if default_to_square:
            if max_size is not None:
                raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
            return {"height": size, "width": size}
        # In other configs, if size is an int and default_to_square is False, size represents the length of
        # the shortest edge after resizing.
        size_dict = {"shortest_edge": size}
        if max_size is not None:
            size_dict["longest_edge"] = max_size
        return size_dict

    # Otherwise, if size is a tuple it's either (height, width) or (width, height).
    # Sequences with fewer than two elements fall through to the ValueError below instead of
    # surfacing as an IndexError from the indexing.
    if isinstance(size, (tuple, list)) and len(size) >= 2:
        first, second = size[0], size[1]
        if height_width_order:
            return {"height": first, "width": second}
        return {"height": second, "width": first}

    raise ValueError(f"Could not convert size input to size dict: {size}")
def get_size_dict( def get_size_dict(
size: Union[int, Iterable[int], Dict[str, int]] = None, size: Union[int, Iterable[int], Dict[str, int]] = None,
max_size: Optional[int] = None, max_size: Optional[int] = None,
height_width_order: bool = True, height_width_order: bool = True,
default_to_square: bool = True, default_to_square: bool = True,
param_name="size",
) -> dict: ) -> dict:
""" """
Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards
...@@ -467,40 +506,19 @@ def get_size_dict( ...@@ -467,40 +506,19 @@ def get_size_dict(
default_to_square (`bool`, *optional*, defaults to `True`): default_to_square (`bool`, *optional*, defaults to `True`):
If `size` is an int, whether to default to a square image or not. If `size` is an int, whether to default to a square image or not.
""" """
# If a dict is passed, we check if it's a valid size dict and then return it. if not isinstance(size, dict):
if isinstance(size, dict): size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order)
size_keys = set(size.keys()) logger.info(
if ( "{param_name} should be a dictionary on of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size}."
size_keys != set(["height", "width"]) " Converted to {size_dict}.",
and size_keys != set(["shortest_edge"]) )
and size_keys != set(["shortest_edge", "longest_edge"]) else:
): size_dict = size
raise ValueError(
"The size dict must contain either the keys ('height', 'width') or ('shortest_edge')"
f"or ('shortest_edge', 'longest_edge') but got {size_keys}"
)
return size
# By default, if size is an int we assume it represents a tuple of (size, size).
elif isinstance(size, int) and default_to_square:
if max_size is not None:
raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
size_dict = {"height": size, "width": size}
# In other configs, if size is an int and default_to_square is False, size represents the length of the shortest edge after resizing.
elif isinstance(size, int) and not default_to_square:
if max_size is not None:
size_dict = {"shortest_edge": size, "longest_edge": max_size}
else:
size_dict = {"shortest_edge": size}
elif isinstance(size, (tuple, list)) and height_width_order:
size_dict = {"height": size[0], "width": size[1]}
elif isinstance(size, (tuple, list)) and not height_width_order:
size_dict = {"height": size[1], "width": size[0]}
logger.info( if not is_valid_size_dict(size_dict):
"The size parameter should be a dictionary with keys ('height', 'width'), ('shortest_edge', 'longest_edge')" raise ValueError(
f" or ('shortest_edge',) got {size}. Setting as {size_dict}.", f"{param_name} must have one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}"
) )
return size_dict return size_dict
......
...@@ -118,7 +118,7 @@ class BeitImageProcessor(BaseImageProcessor): ...@@ -118,7 +118,7 @@ class BeitImageProcessor(BaseImageProcessor):
size = size if size is not None else {"height": 256, "width": 256} size = size if size is not None else {"height": 256, "width": 256}
size = get_size_dict(size) size = get_size_dict(size)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size")
self.do_resize = do_resize self.do_resize = do_resize
self.size = size self.size = size
self.resample = resample self.resample = resample
...@@ -152,7 +152,7 @@ class BeitImageProcessor(BaseImageProcessor): ...@@ -152,7 +152,7 @@ class BeitImageProcessor(BaseImageProcessor):
data_format (`str` or `ChannelDimension`, *optional*): data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
""" """
size = get_size_dict(size) size = get_size_dict(size, default_to_square=True, param_name="size")
if "height" not in size or "width" not in size: if "height" not in size or "width" not in size:
raise ValueError(f"The `size` argument must contain `height` and `width` keys. Got {size.keys()}") raise ValueError(f"The `size` argument must contain `height` and `width` keys. Got {size.keys()}")
return resize( return resize(
...@@ -178,7 +178,7 @@ class BeitImageProcessor(BaseImageProcessor): ...@@ -178,7 +178,7 @@ class BeitImageProcessor(BaseImageProcessor):
data_format (`str` or `ChannelDimension`, *optional*): data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
""" """
size = get_size_dict(size) size = get_size_dict(size, default_to_square=True, param_name="size")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs) return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale( def rescale(
...@@ -406,11 +406,11 @@ class BeitImageProcessor(BaseImageProcessor): ...@@ -406,11 +406,11 @@ class BeitImageProcessor(BaseImageProcessor):
""" """
do_resize = do_resize if do_resize is not None else self.do_resize do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size size = size if size is not None else self.size
size = get_size_dict(size) size = get_size_dict(size, default_to_square=True, param_name="size")
resample = resample if resample is not None else self.resample resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
do_rescale = do_rescale if do_rescale is not None else self.do_rescale do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize do_normalize = do_normalize if do_normalize is not None else self.do_normalize
......
...@@ -114,7 +114,7 @@ class CLIPImageProcessor(BaseImageProcessor): ...@@ -114,7 +114,7 @@ class CLIPImageProcessor(BaseImageProcessor):
size = size if size is not None else {"shortest_edge": 224} size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
self.do_resize = do_resize self.do_resize = do_resize
self.size = size self.size = size
...@@ -176,6 +176,8 @@ class CLIPImageProcessor(BaseImageProcessor): ...@@ -176,6 +176,8 @@ class CLIPImageProcessor(BaseImageProcessor):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
""" """
size = get_size_dict(size) size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` parameter must contain the keys (height, width). Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs) return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale( def rescale(
...@@ -285,11 +287,11 @@ class CLIPImageProcessor(BaseImageProcessor): ...@@ -285,11 +287,11 @@ class CLIPImageProcessor(BaseImageProcessor):
""" """
do_resize = do_resize if do_resize is not None else self.do_resize do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, param_name="size", default_to_square=False)
resample = resample if resample is not None else self.resample resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
do_rescale = do_rescale if do_rescale is not None else self.do_rescale do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize do_normalize = do_normalize if do_normalize is not None else self.do_normalize
......
...@@ -97,7 +97,7 @@ class DeiTImageProcessor(BaseImageProcessor): ...@@ -97,7 +97,7 @@ class DeiTImageProcessor(BaseImageProcessor):
size = size if size is not None else {"height": 256, "width": 256} size = size if size is not None else {"height": 256, "width": 256}
size = get_size_dict(size) size = get_size_dict(size)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size")
self.do_resize = do_resize self.do_resize = do_resize
self.size = size self.size = size
...@@ -158,6 +158,8 @@ class DeiTImageProcessor(BaseImageProcessor): ...@@ -158,6 +158,8 @@ class DeiTImageProcessor(BaseImageProcessor):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
""" """
size = get_size_dict(size) size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs) return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale( def rescale(
...@@ -272,7 +274,7 @@ class DeiTImageProcessor(BaseImageProcessor): ...@@ -272,7 +274,7 @@ class DeiTImageProcessor(BaseImageProcessor):
size = size if size is not None else self.size size = size if size is not None else self.size
size = get_size_dict(size) size = get_size_dict(size)
crop_size = crop_size if crop_size is not None else self.crop_size crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size")
if not is_batched(images): if not is_batched(images):
images = [images] images = [images]
......
...@@ -253,12 +253,12 @@ class FlavaImageProcessor(BaseImageProcessor): ...@@ -253,12 +253,12 @@ class FlavaImageProcessor(BaseImageProcessor):
size = size if size is not None else {"height": 224, "width": 224} size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size) size = get_size_dict(size)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size")
codebook_size = codebook_size if codebook_size is not None else {"height": 112, "width": 112} codebook_size = codebook_size if codebook_size is not None else {"height": 112, "width": 112}
codebook_size = get_size_dict(codebook_size) codebook_size = get_size_dict(codebook_size, param_name="codebook_size")
codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else {"height": 112, "width": 112} codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else {"height": 112, "width": 112}
codebook_crop_size = get_size_dict(codebook_crop_size) codebook_crop_size = get_size_dict(codebook_crop_size, param_name="codebook_crop_size")
self.do_resize = do_resize self.do_resize = do_resize
self.size = size self.size = size
...@@ -360,6 +360,8 @@ class FlavaImageProcessor(BaseImageProcessor): ...@@ -360,6 +360,8 @@ class FlavaImageProcessor(BaseImageProcessor):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
""" """
size = get_size_dict(size) size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The size dictionary must contain 'height' and 'width' keys. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs) return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale( def rescale(
...@@ -580,7 +582,7 @@ class FlavaImageProcessor(BaseImageProcessor): ...@@ -580,7 +582,7 @@ class FlavaImageProcessor(BaseImageProcessor):
resample = resample if resample is not None else self.resample resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size")
do_rescale = do_rescale if do_rescale is not None else self.do_rescale do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize do_normalize = do_normalize if do_normalize is not None else self.do_normalize
...@@ -612,7 +614,7 @@ class FlavaImageProcessor(BaseImageProcessor): ...@@ -612,7 +614,7 @@ class FlavaImageProcessor(BaseImageProcessor):
) )
codebook_do_resize = codebook_do_resize if codebook_do_resize is not None else self.codebook_do_resize codebook_do_resize = codebook_do_resize if codebook_do_resize is not None else self.codebook_do_resize
codebook_size = codebook_size if codebook_size is not None else self.codebook_size codebook_size = codebook_size if codebook_size is not None else self.codebook_size
codebook_size = get_size_dict(codebook_size) codebook_size = get_size_dict(codebook_size, param_name="codebook_size")
codebook_resample = codebook_resample if codebook_resample is not None else self.codebook_resample codebook_resample = codebook_resample if codebook_resample is not None else self.codebook_resample
codebook_do_rescale = codebook_do_rescale if codebook_do_rescale is not None else self.codebook_do_rescale codebook_do_rescale = codebook_do_rescale if codebook_do_rescale is not None else self.codebook_do_rescale
codebook_rescale_factor = ( codebook_rescale_factor = (
...@@ -622,7 +624,7 @@ class FlavaImageProcessor(BaseImageProcessor): ...@@ -622,7 +624,7 @@ class FlavaImageProcessor(BaseImageProcessor):
codebook_do_center_crop if codebook_do_center_crop is not None else self.codebook_do_center_crop codebook_do_center_crop if codebook_do_center_crop is not None else self.codebook_do_center_crop
) )
codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else self.codebook_crop_size codebook_crop_size = codebook_crop_size if codebook_crop_size is not None else self.codebook_crop_size
codebook_crop_size = get_size_dict(codebook_crop_size) codebook_crop_size = get_size_dict(codebook_crop_size, param_name="codebook_crop_size")
codebook_do_map_pixels = ( codebook_do_map_pixels = (
codebook_do_map_pixels if codebook_do_map_pixels is not None else self.codebook_do_map_pixels codebook_do_map_pixels if codebook_do_map_pixels is not None else self.codebook_do_map_pixels
) )
......
...@@ -105,7 +105,7 @@ class LevitImageProcessor(BaseImageProcessor): ...@@ -105,7 +105,7 @@ class LevitImageProcessor(BaseImageProcessor):
size = size if size is not None else {"shortest_edge": 224} size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size")
self.do_resize = do_resize self.do_resize = do_resize
self.size = size self.size = size
...@@ -182,6 +182,8 @@ class LevitImageProcessor(BaseImageProcessor): ...@@ -182,6 +182,8 @@ class LevitImageProcessor(BaseImageProcessor):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
""" """
size = get_size_dict(size) size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"Size dict must have keys 'height' and 'width'. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs) return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale( def rescale(
...@@ -299,7 +301,7 @@ class LevitImageProcessor(BaseImageProcessor): ...@@ -299,7 +301,7 @@ class LevitImageProcessor(BaseImageProcessor):
size = size if size is not None else self.size size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else self.crop_size crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size")
if not is_batched(images): if not is_batched(images):
images = [images] images = [images]
......
...@@ -109,7 +109,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): ...@@ -109,7 +109,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
size = size if size is not None else {"shortest_edge": 256} size = size if size is not None else {"shortest_edge": 256}
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size")
self.do_resize = do_resize self.do_resize = do_resize
self.size = size self.size = size
self.resample = resample self.resample = resample
...@@ -169,6 +169,8 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): ...@@ -169,6 +169,8 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
""" """
size = get_size_dict(size) size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` parameter must contain the keys `height` and `width`. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs) return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale( def rescale(
...@@ -286,7 +288,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor): ...@@ -286,7 +288,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
resample = resample if resample is not None else self.resample resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size")
do_rescale = do_rescale if do_rescale is not None else self.do_rescale do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize do_normalize = do_normalize if do_normalize is not None else self.do_normalize
......
...@@ -123,7 +123,7 @@ class MobileViTImageProcessor(BaseImageProcessor): ...@@ -123,7 +123,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
size = size if size is not None else {"shortest_edge": 224} size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256} crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size")
self.do_resize = do_resize self.do_resize = do_resize
self.size = size self.size = size
...@@ -182,6 +182,8 @@ class MobileViTImageProcessor(BaseImageProcessor): ...@@ -182,6 +182,8 @@ class MobileViTImageProcessor(BaseImageProcessor):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
""" """
size = get_size_dict(size) size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs) return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale( def rescale(
...@@ -280,7 +282,7 @@ class MobileViTImageProcessor(BaseImageProcessor): ...@@ -280,7 +282,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
size = size if size is not None else self.size size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else self.crop_size crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size")
if not is_batched(images): if not is_batched(images):
images = [images] images = [images]
......
...@@ -99,7 +99,7 @@ class PerceiverImageProcessor(BaseImageProcessor): ...@@ -99,7 +99,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
) -> None: ) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256} crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size")
size = size if size is not None else {"height": 224, "width": 224} size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size) size = get_size_dict(size)
...@@ -141,7 +141,7 @@ class PerceiverImageProcessor(BaseImageProcessor): ...@@ -141,7 +141,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
""" """
size = self.size if size is None else size size = self.size if size is None else size
size = get_size_dict(size) size = get_size_dict(size)
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size")
height, width = get_image_size(image) height, width = get_image_size(image)
min_dim = min(height, width) min_dim = min(height, width)
...@@ -278,7 +278,7 @@ class PerceiverImageProcessor(BaseImageProcessor): ...@@ -278,7 +278,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
""" """
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size")
do_resize = do_resize if do_resize is not None else self.do_resize do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size size = size if size is not None else self.size
size = get_size_dict(size) size = get_size_dict(size)
......
...@@ -122,7 +122,7 @@ class PoolFormerImageProcessor(BaseImageProcessor): ...@@ -122,7 +122,7 @@ class PoolFormerImageProcessor(BaseImageProcessor):
size = size if size is not None else {"shortest_edge": 224} size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size")
self.do_resize = do_resize self.do_resize = do_resize
self.size = size self.size = size
...@@ -218,6 +218,8 @@ class PoolFormerImageProcessor(BaseImageProcessor): ...@@ -218,6 +218,8 @@ class PoolFormerImageProcessor(BaseImageProcessor):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
""" """
size = get_size_dict(size) size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"size must contain 'height' and 'width' as keys. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs) return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale( def rescale(
...@@ -335,7 +337,7 @@ class PoolFormerImageProcessor(BaseImageProcessor): ...@@ -335,7 +337,7 @@ class PoolFormerImageProcessor(BaseImageProcessor):
size = size if size is not None else self.size size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else self.crop_size crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size")
if not is_batched(images): if not is_batched(images):
images = [images] images = [images]
......
...@@ -167,6 +167,8 @@ class SegformerImageProcessor(BaseImageProcessor): ...@@ -167,6 +167,8 @@ class SegformerImageProcessor(BaseImageProcessor):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
""" """
size = get_size_dict(size) size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs) return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale( def rescale(
......
...@@ -121,7 +121,7 @@ class VideoMAEImageProcessor(BaseImageProcessor): ...@@ -121,7 +121,7 @@ class VideoMAEImageProcessor(BaseImageProcessor):
size = size if size is not None else {"shortest_edge": 224} size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size")
self.do_resize = do_resize self.do_resize = do_resize
self.size = size self.size = size
...@@ -157,7 +157,7 @@ class VideoMAEImageProcessor(BaseImageProcessor): ...@@ -157,7 +157,7 @@ class VideoMAEImageProcessor(BaseImageProcessor):
data_format (`str` or `ChannelDimension`, *optional*): data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
""" """
size = get_size_dict(size) size = get_size_dict(size, default_to_square=False)
if "shortest_edge" in size: if "shortest_edge" in size:
output_size = get_resize_output_image_size(image, size["shortest_edge"], default_to_square=False) output_size = get_resize_output_image_size(image, size["shortest_edge"], default_to_square=False)
elif "height" in size and "width" in size: elif "height" in size and "width" in size:
...@@ -186,6 +186,8 @@ class VideoMAEImageProcessor(BaseImageProcessor): ...@@ -186,6 +186,8 @@ class VideoMAEImageProcessor(BaseImageProcessor):
The channel dimension format of the image. If not provided, it will be the same as the input image. The channel dimension format of the image. If not provided, it will be the same as the input image.
""" """
size = get_size_dict(size) size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"Size must have 'height' and 'width' as keys. Got {size.keys()}")
return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs) return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)
def rescale( def rescale(
...@@ -346,7 +348,7 @@ class VideoMAEImageProcessor(BaseImageProcessor): ...@@ -346,7 +348,7 @@ class VideoMAEImageProcessor(BaseImageProcessor):
size = size if size is not None else self.size size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False) size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else self.crop_size crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size) crop_size = get_size_dict(crop_size, param_name="crop_size")
if not valid_images(videos): if not valid_images(videos):
raise ValueError( raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment