Unverified Commit d68c4602 authored by amyeroberts, committed by GitHub

Update defaults and logic to match old FE (#20065)

* Update defaults and logic to match old FE

* Use docker run rest values
parent c06d5556
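The net effect of the commit is easiest to see from the resulting defaults. A minimal sketch, assuming a transformers version that includes this commit (the printed values are the expected ones, not output captured from the diff):

from transformers import PerceiverImageProcessor

processor = PerceiverImageProcessor()
print(processor.image_mean)  # [0.485, 0.456, 0.406], i.e. IMAGENET_DEFAULT_MEAN
print(processor.image_std)   # [0.229, 0.224, 0.225], i.e. IMAGENET_DEFAULT_STD
print(processor.resample)    # PILImageResampling.BICUBIC, matching the old feature extractor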
@@ -361,7 +361,6 @@ class LayoutLMv3ImageProcessor(BaseImageProcessor):
             images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]
         # flip color channels from RGB to BGR (as Detectron2 requires this)
-        images = [flip_channel_order(image) for image in images]
         images = [to_channel_dimension_format(image, data_format) for image in images]
         data = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
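For context on the removed call: flip_channel_order reverses the channel axis, converting RGB to BGR. A minimal sketch of the behavior being dropped (flip_channel_order_sketch is an illustrative name; the real helper in transformers.image_transforms also handles channels-first input):

import numpy as np

def flip_channel_order_sketch(image: np.ndarray) -> np.ndarray:
    # Reverse the last (channel) axis of a channels-last image: RGB <-> BGR.
    return image[..., ::-1]

rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[..., 0] = 255                     # pure red in RGB
bgr = flip_channel_order_sketch(rgb)
print(bgr[0, 0])                      # [0, 0, 255]: red moved to the last slot

Dropping the call means LayoutLMv3 images now stay in RGB order, which is presumably the old feature-extractor behavior this commit restores.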
@@ -24,8 +24,8 @@ from transformers.utils.generic import TensorType
 from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
 from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
 from ...image_utils import (
-    IMAGENET_STANDARD_MEAN,
-    IMAGENET_STANDARD_STD,
+    IMAGENET_DEFAULT_MEAN,
+    IMAGENET_DEFAULT_STD,
     ChannelDimension,
     ImageInput,
     PILImageResampling,
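The swap is not cosmetic: the two constant pairs hold different values (as defined in transformers.image_utils), so every normalized pixel changes. A quick worked comparison:

# Values as defined in transformers.image_utils:
IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5]
IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]

# Normalizing the same red-channel pixel value with each pair:
x = 0.8
print((x - 0.5) / 0.5)        # 0.6    (standard constants)
print((x - 0.485) / 0.229)    # ~1.376 (default constants)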
@@ -61,7 +61,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
             parameter in the `preprocess` method.
         size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
             Size of the image after resizing. Can be overridden by the `size` parameter in the `preprocess` method.
-        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
             Defines the resampling filter to use if resizing the image. Can be overridden by the `resample` parameter
             in the `preprocess` method.
         do_rescale (`bool`, *optional*, defaults to `True`):
@@ -89,7 +89,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
         crop_size: Dict[str, int] = None,
         do_resize: bool = True,
         size: Dict[str, int] = None,
-        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
         do_rescale: bool = True,
         rescale_factor: Union[int, float] = 1 / 255,
         do_normalize: bool = True,
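Both the docstring change above and this signature change switch Perceiver's default filter from bilinear to bicubic. The two filters produce genuinely different pixels, so the default matters for bit-exact reproduction of the old feature extractor's output; a small PIL sketch:

import numpy as np
from PIL import Image

img = Image.fromarray(np.arange(64, dtype=np.uint8).reshape(8, 8))
bilinear = np.asarray(img.resize((4, 4), resample=Image.BILINEAR)).astype(int)
bicubic = np.asarray(img.resize((4, 4), resample=Image.BICUBIC)).astype(int)
print(np.abs(bilinear - bicubic).max())  # typically non-zero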
@@ -111,8 +111,8 @@ class PerceiverImageProcessor(BaseImageProcessor):
         self.do_rescale = do_rescale
         self.rescale_factor = rescale_factor
         self.do_normalize = do_normalize
-        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
-        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD

     def center_crop(
         self,
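Anyone who depended on the previous (incorrect) defaults can pin them explicitly rather than relying on the class constants; a sketch:

from transformers import PerceiverImageProcessor

processor = PerceiverImageProcessor(
    image_mean=[0.5, 0.5, 0.5],  # former IMAGENET_STANDARD_MEAN
    image_std=[0.5, 0.5, 0.5],   # former IMAGENET_STANDARD_STD
)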
@@ -153,7 +153,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
         self,
         image: np.ndarray,
         size: Dict[str, int],
-        resample: PILImageResampling = PIL.Image.BILINEAR,
+        resample: PILImageResampling = PIL.Image.BICUBIC,
         data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs
     ) -> np.ndarray:
@@ -165,7 +165,7 @@ class PerceiverImageProcessor(BaseImageProcessor):
                 Image to resize.
             size (`Dict[str, int]`):
                 Size of the output image.
-            resample (`PILImageResampling`, *optional*, defaults to `PIL.Image.BILINEAR`):
+            resample (`PILImageResampling`, *optional*, defaults to `PIL.Image.BICUBIC`):
                 Resampling filter to use when resizing the image.
             data_format (`str` or `ChannelDimension`, *optional*):
                 The channel dimension format of the image. If not provided, it will be the same as the input image.
@@ -31,8 +31,8 @@ from ...image_transforms import (
     to_channel_dimension_format,
 )
 from ...image_utils import (
-    IMAGENET_STANDARD_MEAN,
-    IMAGENET_STANDARD_STD,
+    IMAGENET_DEFAULT_MEAN,
+    IMAGENET_DEFAULT_STD,
     ChannelDimension,
     ImageInput,
     PILImageResampling,
@@ -133,8 +133,8 @@ class PoolFormerImageProcessor(BaseImageProcessor):
         self.do_rescale = do_rescale
         self.rescale_factor = rescale_factor
         self.do_normalize = do_normalize
-        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
-        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD

     def resize(
         self,
@@ -25,8 +25,8 @@ from transformers.utils.generic import TensorType
 from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
 from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
 from ...image_utils import (
-    IMAGENET_STANDARD_MEAN,
-    IMAGENET_STANDARD_STD,
+    IMAGENET_DEFAULT_MEAN,
+    IMAGENET_DEFAULT_STD,
     ChannelDimension,
     ImageInput,
     PILImageResampling,
@@ -115,15 +115,15 @@ class SegformerImageProcessor(BaseImageProcessor):
         self.do_rescale = do_rescale
         self.rescale_factor = rescale_factor
         self.do_normalize = do_normalize
-        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
-        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
         self.do_reduce_labels = do_reduce_labels

     def resize(
         self,
         image: np.ndarray,
         size: Dict[str, int],
-        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
         data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs
     ) -> np.ndarray:
@@ -135,7 +135,7 @@ class SegformerImageProcessor(BaseImageProcessor):
                 Image to resize.
             size (`Dict[str, int]`):
                 Size of the output image.
-            resample (`PILImageResampling`, *optional*, defaults to `PIL.Image.BICUBIC`):
+            resample (`PILImageResampling`, *optional*, defaults to `PIL.Image.BILINEAR`):
                 Resampling filter to use when resizing the image.
             data_format (`str` or `ChannelDimension`, *optional*):
                 The channel dimension format of the image. If not provided, it will be the same as the input image.
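Note that Segformer moves in the opposite direction from Perceiver, from bicubic back to bilinear. The per-call resample argument still overrides the default, so bicubic output remains available; a sketch assuming the preprocess signature shown above:

import numpy as np
from transformers import SegformerImageProcessor
from transformers.image_utils import PILImageResampling

processor = SegformerImageProcessor()
image = np.zeros((512, 512, 3), dtype=np.uint8)
outputs = processor(image, resample=PILImageResampling.BICUBIC, return_tensors="np")
print(outputs["pixel_values"].shape)  # (1, 3, 512, 512) with the default 512x512 size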
@@ -903,7 +903,7 @@ class PerceiverModelIntegrationTest(unittest.TestCase):
         expected_shape = torch.Size((1, model.config.num_labels))
         self.assertEqual(logits.shape, expected_shape)

-        expected_slice = torch.tensor([-1.1653, -0.1993, -0.7521], device=torch_device)
+        expected_slice = torch.tensor([-1.1652, -0.1992, -0.7520], device=torch_device)

         self.assertTrue(torch.allclose(logits[0, :3], expected_slice, atol=1e-4))
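The updated reference slice differs from the old one by 1e-4 per element, i.e. right at the test's atol, consistent with the logits shifting slightly under the corrected preprocessing; a quick check:

import torch

old = torch.tensor([-1.1653, -0.1993, -0.7521])
new = torch.tensor([-1.1652, -0.1992, -0.7520])
print((old - new).abs().max())              # ~1e-4
print(torch.allclose(old, new, atol=1e-4))  # borderline at the test tolerance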