"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4766e009b034ec698f09526a2225b6b5fc34a75e"
Unverified commit df40edfb authored by NielsRogge, committed by GitHub

Make image processors more general (#27690)

* Make image processors more general

* Add backwards compatibility for KOSMOS-2

* Remove use_square_size everywhere

* Remove script
parent 96f9caa1
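
As a quick illustration of what the change enables (a minimal sketch, not part of the commit; it assumes a transformers build containing this patch plus NumPy and Pillow), the image processors now accept a `size` dict with either a `shortest_edge` key or explicit `height`/`width` keys, and the `use_square_size` flag is gone:

    import numpy as np
    from transformers import CLIPImageProcessor

    image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

    # Aspect-ratio-preserving resize via the shortest edge (the previous behaviour).
    clip_short = CLIPImageProcessor(size={"shortest_edge": 224})

    # Fixed-size resize, previously only reachable through the removed `use_square_size` flag.
    clip_exact = CLIPImageProcessor(size={"height": 256, "width": 256}, do_center_crop=False)

    print(clip_short(images=image, return_tensors="np")["pixel_values"].shape)  # expected (1, 3, 224, 224)
    print(clip_exact(images=image, return_tensors="np")["pixel_values"].shape)  # expected (1, 3, 256, 256)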
@@ -84,10 +84,6 @@ class BitImageProcessor(BaseImageProcessor):
             Can be overridden by the `image_std` parameter in the `preprocess` method.
         do_convert_rgb (`bool`, *optional*, defaults to `True`):
             Whether to convert the image to RGB.
-        use_square_size (`bool`, *optional*, defaults to `False`):
-            The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the
-            `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or not.
-            Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`.
     """

     model_input_names = ["pixel_values"]
@@ -105,12 +101,11 @@ class BitImageProcessor(BaseImageProcessor):
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
         do_convert_rgb: bool = True,
-        use_square_size: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         size = size if size is not None else {"shortest_edge": 224}
-        size = get_size_dict(size, default_to_square=use_square_size)
+        size = get_size_dict(size, default_to_square=False)
         crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
         crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
@@ -125,7 +120,6 @@ class BitImageProcessor(BaseImageProcessor):
         self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
         self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
         self.do_convert_rgb = do_convert_rgb
-        self.use_square_size = use_square_size

     # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize
     def resize(
@@ -153,13 +147,19 @@ class BitImageProcessor(BaseImageProcessor):
             input_data_format (`ChannelDimension` or `str`, *optional*):
                 The channel dimension format of the input image. If not provided, it will be inferred.
         """
-        size = get_size_dict(size, default_to_square=self.use_square_size)
-        if "shortest_edge" not in size:
-            raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
+        default_to_square = True
+        if "shortest_edge" in size:
+            size = size["shortest_edge"]
+            default_to_square = False
+        elif "height" in size and "width" in size:
+            size = (size["height"], size["width"])
+        else:
+            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
         output_size = get_resize_output_image_size(
             image,
-            size=size["shortest_edge"],
-            default_to_square=self.use_square_size,
+            size=size,
+            default_to_square=default_to_square,
             input_data_format=input_data_format,
         )
         return resize(
@@ -243,7 +243,7 @@ class BitImageProcessor(BaseImageProcessor):
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
-        size = get_size_dict(size, param_name="size", default_to_square=self.use_square_size)
+        size = get_size_dict(size, param_name="size", default_to_square=False)
         resample = resample if resample is not None else self.resample
         do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
         crop_size = crop_size if crop_size is not None else self.crop_size
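
The same key dispatch applies at the `resize` level; a short sketch (again illustrative only, assuming NumPy and Pillow are installed) of calling the updated method on `BitImageProcessor` directly:

    import numpy as np
    from transformers import BitImageProcessor

    processor = BitImageProcessor()
    image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

    # Shortest-edge resize keeps the aspect ratio: the 480 side becomes 224.
    resized_short = processor.resize(image, size={"shortest_edge": 224})

    # Height/width resize produces an exact output shape, no `use_square_size` needed.
    resized_exact = processor.resize(image, size={"height": 224, "width": 224})
    print(resized_exact.shape)  # expected (224, 224, 3)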
......
@@ -84,10 +84,6 @@ class CLIPImageProcessor(BaseImageProcessor):
             Can be overridden by the `image_std` parameter in the `preprocess` method.
         do_convert_rgb (`bool`, *optional*, defaults to `True`):
             Whether to convert the image to RGB.
-        use_square_size (`bool`, *optional*, defaults to `False`):
-            The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the
-            `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or not.
-            Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`.
     """

     model_input_names = ["pixel_values"]
@@ -105,12 +101,11 @@ class CLIPImageProcessor(BaseImageProcessor):
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
         do_convert_rgb: bool = True,
-        use_square_size: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         size = size if size is not None else {"shortest_edge": 224}
-        size = get_size_dict(size, default_to_square=use_square_size)
+        size = get_size_dict(size, default_to_square=False)
         crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
         crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
@@ -125,7 +120,10 @@ class CLIPImageProcessor(BaseImageProcessor):
         self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
         self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
         self.do_convert_rgb = do_convert_rgb
-        self.use_square_size = use_square_size
+
+        # for backwards compatibility of KOSMOS-2
+        if "use_square_size" in kwargs:
+            self.size = {"height": size["shortest_edge"], "width": size["shortest_edge"]}

     def resize(
         self,
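
The hunk above is the backwards-compatibility path for KOSMOS-2: a config that still passes the removed `use_square_size=True` flag (which now falls into `**kwargs`) gets an explicit square `size` dict. A hedged sketch of the resulting behaviour:

    from transformers import CLIPImageProcessor

    # The legacy flag is no longer a constructor argument; it only rewrites `size` here.
    legacy = CLIPImageProcessor(size={"shortest_edge": 224}, use_square_size=True)
    print(legacy.size)  # expected {"height": 224, "width": 224}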
@@ -152,13 +150,19 @@ class CLIPImageProcessor(BaseImageProcessor):
             input_data_format (`ChannelDimension` or `str`, *optional*):
                 The channel dimension format of the input image. If not provided, it will be inferred.
         """
-        size = get_size_dict(size, default_to_square=self.use_square_size)
-        if "shortest_edge" not in size:
-            raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
+        default_to_square = True
+        if "shortest_edge" in size:
+            size = size["shortest_edge"]
+            default_to_square = False
+        elif "height" in size and "width" in size:
+            size = (size["height"], size["width"])
+        else:
+            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
         output_size = get_resize_output_image_size(
             image,
-            size=size["shortest_edge"],
-            default_to_square=self.use_square_size,
+            size=size,
+            default_to_square=default_to_square,
             input_data_format=input_data_format,
         )
         return resize(
@@ -242,7 +246,7 @@ class CLIPImageProcessor(BaseImageProcessor):
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
-        size = get_size_dict(size, param_name="size", default_to_square=self.use_square_size)
+        size = get_size_dict(size, param_name="size", default_to_square=False)
         resample = resample if resample is not None else self.resample
         do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
         crop_size = crop_size if crop_size is not None else self.crop_size
......
@@ -79,10 +79,6 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
         image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
             Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
             number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
-        use_square_size (`bool`, *optional*, defaults to `False`):
-            The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the
-            `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or not.
-            Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`.
     """

     model_input_names = ["pixel_values"]
@@ -99,12 +95,11 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
         do_normalize: bool = True,
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
-        use_square_size: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         size = size if size is not None else {"shortest_edge": 256}
-        size = get_size_dict(size, default_to_square=use_square_size)
+        size = get_size_dict(size, default_to_square=False)
         crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
         crop_size = get_size_dict(crop_size)
         self.do_resize = do_resize
@@ -117,7 +112,6 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
         self.do_normalize = do_normalize
         self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
         self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
-        self.use_square_size = use_square_size

     # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize
     def resize(
@@ -145,13 +139,19 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
             input_data_format (`ChannelDimension` or `str`, *optional*):
                 The channel dimension format of the input image. If not provided, it will be inferred.
         """
-        size = get_size_dict(size, default_to_square=self.use_square_size)
-        if "shortest_edge" not in size:
-            raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
+        default_to_square = True
+        if "shortest_edge" in size:
+            size = size["shortest_edge"]
+            default_to_square = False
+        elif "height" in size and "width" in size:
+            size = (size["height"], size["width"])
+        else:
+            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
         output_size = get_resize_output_image_size(
             image,
-            size=size["shortest_edge"],
-            default_to_square=self.use_square_size,
+            size=size,
+            default_to_square=default_to_square,
             input_data_format=input_data_format,
         )
         return resize(
@@ -231,7 +231,7 @@ class MobileNetV1ImageProcessor(BaseImageProcessor):
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
-        size = get_size_dict(size, default_to_square=self.use_square_size)
+        size = get_size_dict(size, default_to_square=False)
         resample = resample if resample is not None else self.resample
         do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
         crop_size = crop_size if crop_size is not None else self.crop_size
......
@@ -83,10 +83,6 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
         image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
             Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
             number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
-        use_square_size (`bool`, *optional*, defaults to `False`):
-            The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the
-            `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or not.
-            Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`.
     """

     model_input_names = ["pixel_values"]
@@ -103,12 +99,11 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
         do_normalize: bool = True,
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
-        use_square_size: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         size = size if size is not None else {"shortest_edge": 256}
-        size = get_size_dict(size, default_to_square=use_square_size)
+        size = get_size_dict(size, default_to_square=False)
         crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
         crop_size = get_size_dict(crop_size, param_name="crop_size")
         self.do_resize = do_resize
@@ -121,7 +116,6 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
         self.do_normalize = do_normalize
         self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
         self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
-        self.use_square_size = use_square_size

     # Copied from transformers.models.mobilenet_v1.image_processing_mobilenet_v1.MobileNetV1ImageProcessor.resize
     def resize(
@@ -149,13 +143,19 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
             input_data_format (`ChannelDimension` or `str`, *optional*):
                 The channel dimension format of the input image. If not provided, it will be inferred.
         """
-        size = get_size_dict(size, default_to_square=self.use_square_size)
-        if "shortest_edge" not in size:
-            raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
+        default_to_square = True
+        if "shortest_edge" in size:
+            size = size["shortest_edge"]
+            default_to_square = False
+        elif "height" in size and "width" in size:
+            size = (size["height"], size["width"])
+        else:
+            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
         output_size = get_resize_output_image_size(
             image,
-            size=size["shortest_edge"],
-            default_to_square=self.use_square_size,
+            size=size,
+            default_to_square=default_to_square,
             input_data_format=input_data_format,
         )
         return resize(
@@ -235,7 +235,7 @@ class MobileNetV2ImageProcessor(BaseImageProcessor):
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
-        size = get_size_dict(size, default_to_square=self.use_square_size)
+        size = get_size_dict(size, default_to_square=False)
         resample = resample if resample is not None else self.resample
         do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
         crop_size = crop_size if crop_size is not None else self.crop_size
......
@@ -78,10 +78,6 @@ class MobileViTImageProcessor(BaseImageProcessor):
         do_flip_channel_order (`bool`, *optional*, defaults to `True`):
             Whether to flip the color channels from RGB to BGR. Can be overridden by the `do_flip_channel_order`
            parameter in the `preprocess` method.
-        use_square_size (`bool`, *optional*, defaults to `False`):
-            The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the
-            `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or not.
-            Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`.
     """

     model_input_names = ["pixel_values"]
@@ -96,12 +92,11 @@ class MobileViTImageProcessor(BaseImageProcessor):
         do_center_crop: bool = True,
         crop_size: Dict[str, int] = None,
         do_flip_channel_order: bool = True,
-        use_square_size: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         size = size if size is not None else {"shortest_edge": 224}
-        size = get_size_dict(size, default_to_square=use_square_size)
+        size = get_size_dict(size, default_to_square=False)
         crop_size = crop_size if crop_size is not None else {"height": 256, "width": 256}
         crop_size = get_size_dict(crop_size, param_name="crop_size")
@@ -113,7 +108,6 @@ class MobileViTImageProcessor(BaseImageProcessor):
         self.do_center_crop = do_center_crop
         self.crop_size = crop_size
         self.do_flip_channel_order = do_flip_channel_order
-        self.use_square_size = use_square_size

     # Copied from transformers.models.mobilenet_v1.image_processing_mobilenet_v1.MobileNetV1ImageProcessor.resize with PILImageResampling.BICUBIC->PILImageResampling.BILINEAR
     def resize(
@@ -141,13 +135,19 @@ class MobileViTImageProcessor(BaseImageProcessor):
             input_data_format (`ChannelDimension` or `str`, *optional*):
                 The channel dimension format of the input image. If not provided, it will be inferred.
         """
-        size = get_size_dict(size, default_to_square=self.use_square_size)
-        if "shortest_edge" not in size:
-            raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
+        default_to_square = True
+        if "shortest_edge" in size:
+            size = size["shortest_edge"]
+            default_to_square = False
+        elif "height" in size and "width" in size:
+            size = (size["height"], size["width"])
+        else:
+            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
         output_size = get_resize_output_image_size(
             image,
-            size=size["shortest_edge"],
-            default_to_square=self.use_square_size,
+            size=size,
+            default_to_square=default_to_square,
             input_data_format=input_data_format,
         )
         return resize(
@@ -246,7 +246,7 @@ class MobileViTImageProcessor(BaseImageProcessor):
         )

         size = size if size is not None else self.size
-        size = get_size_dict(size, default_to_square=self.use_square_size)
+        size = get_size_dict(size, default_to_square=False)
         crop_size = crop_size if crop_size is not None else self.crop_size
         crop_size = get_size_dict(crop_size, param_name="crop_size")
......
@@ -84,10 +84,6 @@ class ViTHybridImageProcessor(BaseImageProcessor):
             Can be overridden by the `image_std` parameter in the `preprocess` method.
         do_convert_rgb (`bool`, *optional*, defaults to `True`):
             Whether to convert the image to RGB.
-        use_square_size (`bool`, *optional*, defaults to `False`):
-            The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the
-            `size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or not.
-            Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`.
     """

     model_input_names = ["pixel_values"]
@@ -105,12 +101,11 @@ class ViTHybridImageProcessor(BaseImageProcessor):
         image_mean: Optional[Union[float, List[float]]] = None,
         image_std: Optional[Union[float, List[float]]] = None,
         do_convert_rgb: bool = True,
-        use_square_size: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         size = size if size is not None else {"shortest_edge": 224}
-        size = get_size_dict(size, default_to_square=use_square_size)
+        size = get_size_dict(size, default_to_square=False)
         crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
         crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
@@ -125,7 +120,6 @@ class ViTHybridImageProcessor(BaseImageProcessor):
         self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
         self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
         self.do_convert_rgb = do_convert_rgb
-        self.use_square_size = use_square_size

     # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize
     def resize(
@@ -153,13 +147,19 @@ class ViTHybridImageProcessor(BaseImageProcessor):
             input_data_format (`ChannelDimension` or `str`, *optional*):
                 The channel dimension format of the input image. If not provided, it will be inferred.
         """
-        size = get_size_dict(size, default_to_square=self.use_square_size)
-        if "shortest_edge" not in size:
-            raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
+        default_to_square = True
+        if "shortest_edge" in size:
+            size = size["shortest_edge"]
+            default_to_square = False
+        elif "height" in size and "width" in size:
+            size = (size["height"], size["width"])
+        else:
+            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
         output_size = get_resize_output_image_size(
             image,
-            size=size["shortest_edge"],
-            default_to_square=self.use_square_size,
+            size=size,
+            default_to_square=default_to_square,
             input_data_format=input_data_format,
         )
         return resize(
@@ -243,7 +243,7 @@ class ViTHybridImageProcessor(BaseImageProcessor):
         """
         do_resize = do_resize if do_resize is not None else self.do_resize
         size = size if size is not None else self.size
-        size = get_size_dict(size, param_name="size", default_to_square=self.use_square_size)
+        size = get_size_dict(size, param_name="size", default_to_square=False)
         resample = resample if resample is not None else self.resample
         do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
         crop_size = crop_size if crop_size is not None else self.crop_size
......
@@ -55,7 +55,7 @@ class Kosmos2ProcessorTest(unittest.TestCase):
     def setUp(self):
         self.tmpdirname = tempfile.mkdtemp()

-        image_processor = CLIPImageProcessor(use_square_size=True)
+        image_processor = CLIPImageProcessor()

         # We have a SentencePiece fixture for testing
         slow_tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB)
......