Unverified commit b56f17ae authored by vfdev, committed by GitHub

Added antialias option to transforms.functional.resize (#3761)

* WIP Added antialias option to transforms.functional.resize

* Updates according to the review

* Excluded these C++ files for iOS build

* Added support for mixed downsampling/upsampling

* Fixed heap overflow caused by explicit loop unrolling

* Applied PR review suggestions:
  - used pytest parametrize instead of unittest
  - cast to scalar_t pointer
  - removed the interpolate_aa files for iOS/Android, keeping the original CMake version
parent d6fee5a4
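
For orientation before the diffs, a minimal usage sketch of the option this PR adds (shapes and sizes here are illustrative; per the changes below, the flag defaults to False for tensors and is bilinear-only at this stage):

import torch
from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode

# An arbitrary uint8 CHW image tensor.
img = torch.randint(0, 256, (3, 320, 290), dtype=torch.uint8)

# Downsample with antialiasing enabled (tensor path, bilinear interpolation).
out = F.resize(img, size=[96, 72], interpolation=InterpolationMode.BILINEAR, antialias=True)
print(out.shape)  # torch.Size([3, 96, 72])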
@@ -14,6 +14,13 @@ file(GLOB VISION_SRCS
    ../../torchvision/csrc/ops/*.h
    ../../torchvision/csrc/ops/*.cpp)
# Remove interpolate_aa sources as they are temporary code
# see https://github.com/pytorch/vision/pull/3761
# and IndexingUtils.h is unavailable on Android build
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp")
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/interpolate_aa.cpp")
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/interpolate_aa.h")
add_library(${TARGET} SHARED
    ${VISION_SRCS}
)
...
@@ -11,6 +11,13 @@ file(GLOB VISION_SRCS
    ../torchvision/csrc/ops/*.h
    ../torchvision/csrc/ops/*.cpp)
# Remove interpolate_aa sources as they are temporary code
# see https://github.com/pytorch/vision/pull/3761
# and TensorIterator is unavailable on iOS
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp")
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/interpolate_aa.cpp")
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/interpolate_aa.h")
add_library(${TARGET} STATIC
    ${VISION_SRCS}
)
...
@@ -1018,5 +1018,52 @@ def test_perspective_interpolation_warning(tester):
    tester.assertTrue(res1.equal(res2))

@pytest.mark.parametrize('device', ["cpu", ])
@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize('size', [[96, 72], [96, 420], [420, 72]])
@pytest.mark.parametrize('interpolation', [BILINEAR, ])
def test_resize_antialias(device, dt, size, interpolation, tester):

    if dt == torch.float16 and device == "cpu":
        # skip float16 on CPU case
        return

    script_fn = torch.jit.script(F.resize)
    tensor, pil_img = tester._create_data(320, 290, device=device)

    if dt is not None:
        # This is a trivial cast to float of uint8 data to test all cases
        tensor = tensor.to(dt)

    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
    resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)

    tester.assertEqual(
        resized_tensor.size()[1:], resized_pil_img.size[::-1],
        msg=f"{size}, {interpolation}, {dt}"
    )

    resized_tensor_f = resized_tensor
    # we need to cast uint8 to float to compare with the PIL image
    if resized_tensor_f.dtype == torch.uint8:
        resized_tensor_f = resized_tensor_f.to(torch.float)

    tester.approxEqualTensorToPIL(
        resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}"
    )
    tester.approxEqualTensorToPIL(
        resized_tensor_f, resized_pil_img, tol=1.0 + 1e-5, agg_method="max",
        msg=f"{size}, {interpolation}, {dt}"
    )

    if isinstance(size, int):
        script_size = [size, ]
    else:
        script_size = size

    resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, antialias=True)
    tester.assertTrue(resized_tensor.equal(resize_result), msg=f"{size}, {interpolation}, {dt}")

if __name__ == '__main__':
    unittest.main()
@@ -348,6 +348,10 @@ class Tester(unittest.TestCase):
                self.assertEqual((owidth, oheight), result.size)

        with self.assertWarnsRegex(UserWarning, r"Anti-alias option is always applied for PIL Image input"):
            t = transforms.Resize(osize, antialias=False)
            t(img)

    def test_random_crop(self):
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
...
#include <ATen/TypeDefault.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/UpSample.h>
#include <cmath>
#include <vector>
#include <torch/library.h>
// This code lives in torchvision temporarily, before being merged into PyTorch
namespace at {
namespace native {
namespace internal_upsample {
using scale_t = std::vector<c10::optional<double>>;
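// Note: the single-dim helpers below receive `data` pointing past the dst/src
// operands, so data[0] = source byte offsets (xmin * stride), data[1] = kernel
// sizes, data[2] = index stride, data[3] = weights, data[4] = weight byte
// offsets, matching the tensor order produced by _compute_indices_weights_aa.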
template <typename scalar_t, typename index_t>
static inline scalar_t interpolate_aa_single_dim_zero_strides(
char* src,
char** data,
int64_t i,
const index_t ids_stride) {
const index_t ids_min = *(index_t*)&data[0][0];
const index_t ids_size = *(index_t*)&data[1][0];
char* src_min = src + ids_min;
scalar_t t = *(scalar_t*)&src_min[0];
index_t wts_idx = *(index_t*)&data[4][0];
scalar_t* wts_ptr = (scalar_t*)&data[3][wts_idx];
scalar_t wts = wts_ptr[0];
scalar_t output = t * wts;
int j = 1;
for (; j < ids_size; j++) {
wts = wts_ptr[j];
t = *(scalar_t*)&src_min[j * ids_stride];
output += t * wts;
}
return output;
}
template <typename scalar_t, typename index_t>
static inline scalar_t interpolate_aa_single_dim(
char* src,
char** data,
const int64_t* strides,
int64_t i,
const index_t ids_stride) {
index_t ids_min = *(index_t*)&data[0][i * strides[0]];
index_t ids_size = *(index_t*)&data[1][i * strides[1]];
char* src_min = src + ids_min;
scalar_t t = *(scalar_t*)&src_min[0];
index_t wts_idx = *(index_t*)&data[4][i * strides[4]];
scalar_t* wts_ptr = (scalar_t*)&data[3][wts_idx];
scalar_t wts = wts_ptr[0];
scalar_t output = t * wts;
int j = 1;
for (; j < ids_size; j++) {
wts = wts_ptr[j];
t = *(scalar_t*)&src_min[j * ids_stride];
output += t * wts;
}
return output;
}
template <typename scalar_t, typename index_t>
static inline void basic_loop_aa_single_dim_zero_strides(
char** data,
const int64_t* strides,
int64_t n) {
char* dst = data[0];
char* src = data[1];
// index stride is constant for the given dimension
const index_t ids_stride = *(index_t*)&data[2 + 2][0];
for (int64_t i = 0; i < n; i++) {
*(scalar_t*)&dst[i * strides[0]] =
interpolate_aa_single_dim_zero_strides<scalar_t, index_t>(
src + i * strides[1], &data[2], i, ids_stride);
}
}
template <typename scalar_t, typename index_t>
static inline void basic_loop_aa_single_dim_nonzero_strides(
char** data,
const int64_t* strides,
int64_t n) {
char* dst = data[0];
char* src = data[1];
// index stride is constant for the given dimension
const index_t ids_stride = *(index_t*)&data[2 + 2][0];
if (strides[1] == 0) {
for (int64_t i = 0; i < n; i++) {
*(scalar_t*)&dst[i * strides[0]] =
interpolate_aa_single_dim<scalar_t, index_t>(
src, &data[2], &strides[2], i, ids_stride);
}
} else {
for (int64_t i = 0; i < n; i++) {
*(scalar_t*)&dst[i * strides[0]] =
interpolate_aa_single_dim<scalar_t, index_t>(
src + i * strides[1], &data[2], &strides[2], i, ids_stride);
}
}
}
template <int m>
static inline bool is_zero_stride(const int64_t* strides) {
bool output = strides[0] == 0;
for (int i = 1; i < m; i++) {
output &= (strides[i] == 0);
}
return output;
}
template <typename scalar_t, typename index_t, int out_ndims>
void ti_cpu_upsample_generic_aa(
at::TensorIterator& iter,
int interp_size = -1) {
TORCH_INTERNAL_ASSERT(interp_size > 0);
auto loop = [&](char** data, const int64_t* strides, int64_t n) {
if ((strides[0] == sizeof(scalar_t)) && (strides[1] == sizeof(scalar_t)) &&
is_zero_stride<3 + 2>(&strides[2])) {
basic_loop_aa_single_dim_zero_strides<scalar_t, index_t>(
data, strides, n);
} else {
basic_loop_aa_single_dim_nonzero_strides<scalar_t, index_t>(
data, strides, n);
}
};
iter.for_each(loop);
}
// Helper structs to use with ti_upsample_generic_Nd_kernel_impl
template <typename index_t, typename scalar_t>
struct HelperInterpBase {
static inline void init_indices_weights(
std::vector<Tensor>& output,
int64_t output_size,
int64_t ndims,
int64_t reshape_dim,
int interp_size) {
auto new_shape = std::vector<int64_t>(ndims, 1);
new_shape[reshape_dim] = output_size;
for (int j = 0; j < interp_size; j++) {
output.emplace_back(
empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
output.emplace_back(
empty(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>())));
}
}
};
template <typename index_t, typename scalar_t>
struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {
static const int interp_size = 2;
static inline std::vector<Tensor> compute_indices_weights(
int64_t input_size,
int64_t output_size,
int64_t stride,
int64_t ndims,
int64_t reshape_dim,
bool align_corners,
const c10::optional<double> opt_scale,
bool antialias,
int& out_interp_size) {
scalar_t scale = area_pixel_compute_scale<scalar_t>(
input_size, output_size, align_corners, opt_scale);
TORCH_INTERNAL_ASSERT(antialias);
return _compute_indices_weights_aa(
input_size,
output_size,
stride,
ndims,
reshape_dim,
align_corners,
scale,
out_interp_size);
}
// taken from
// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
// src/libImaging/Resample.c#L20-L29
static inline scalar_t _filter(scalar_t x) {
if (x < 0.0) {
x = -x;
}
if (x < 1.0) {
return 1.0 - x;
}
return 0.0;
}
static inline std::vector<Tensor> _compute_indices_weights_aa(
int64_t input_size,
int64_t output_size,
int64_t stride,
int64_t ndims,
int64_t reshape_dim,
bool align_corners,
scalar_t scale,
int& out_interp_size) {
int interp_size = HelperInterpLinear<index_t, scalar_t>::interp_size;
scalar_t support =
(scale >= 1.0) ? (interp_size / 2) * scale : interp_size / 2 * 1.0;
interp_size = (int)ceilf(support) * 2 + 1;
// pass the computed interpolation kernel size back to the caller
out_interp_size = interp_size;
std::vector<Tensor> output;
auto new_shape = std::vector<int64_t>(ndims, 1);
new_shape[reshape_dim] = output_size;
// ---- Bounds approach as in PIL -----
// bounds: xmin/xmax
output.emplace_back(
empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
output.emplace_back(
empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
output.emplace_back(
empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
{
// Weights
new_shape[reshape_dim] = output_size * interp_size;
auto wts = empty(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>()));
auto strides = wts.strides().vec();
strides[reshape_dim] = 0;
new_shape[reshape_dim] = output_size;
wts = wts.as_strided(new_shape, strides);
output.emplace_back(wts);
// Weights indices
output.emplace_back(
empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
}
scalar_t center, total_w, invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
index_t zero = static_cast<index_t>(0);
index_t* idx_ptr_xmin = output[0].data_ptr<index_t>();
index_t* idx_ptr_size = output[1].data_ptr<index_t>();
index_t* idx_ptr_stride = output[2].data_ptr<index_t>();
scalar_t* wt_ptr = output[3].data_ptr<scalar_t>();
index_t* wt_idx_ptr = output[4].data_ptr<index_t>();
int64_t xmin, xmax, j;
for (int64_t i = 0; i < output_size; i++) {
center = scale * (i + 0.5);
xmin = std::max(static_cast<int64_t>(center - support + 0.5), zero);
xmax =
std::min(static_cast<int64_t>(center + support + 0.5), input_size) -
xmin;
idx_ptr_xmin[i] = xmin * stride;
idx_ptr_size[i] = xmax;
idx_ptr_stride[i] = stride;
wt_idx_ptr[i] = i * interp_size * sizeof(scalar_t);
total_w = 0.0;
for (j = 0; j < xmax; j++) {
scalar_t w = _filter((j + xmin - center + 0.5) * invscale);
wt_ptr[i * interp_size + j] = w;
total_w += w;
}
for (j = 0; j < xmax; j++) {
if (total_w != 0.0) {
wt_ptr[i * interp_size + j] /= total_w;
}
}
for (; j < interp_size; j++) {
wt_ptr[i * interp_size + j] = static_cast<scalar_t>(0.0);
}
}
return output;
}
};
template <
typename index_t,
int out_ndims,
typename scale_type,
template <typename, typename>
class F>
void _ti_separable_upsample_generic_Nd_kernel_impl_single_dim(
Tensor& output,
const Tensor& input,
int interp_dim,
bool align_corners,
const scale_type& scales,
bool antialias) {
// input can be NCHW, NCL or NCKHW
auto shape = input.sizes().vec();
auto strides = input.strides().vec();
auto oshape = output.sizes();
TORCH_INTERNAL_ASSERT(
shape.size() == oshape.size() && shape.size() == 2 + out_ndims);
TORCH_INTERNAL_ASSERT(strides.size() == 2 + out_ndims);
TORCH_INTERNAL_ASSERT(antialias);
for (int i = 0; i < out_ndims; i++) {
shape[i + 2] = oshape[i + 2];
}
strides[interp_dim] = 0;
auto restrided_input = input.as_strided(shape, strides);
std::vector<std::vector<Tensor>> indices_weights;
int interp_size = F<index_t, float>::interp_size;
auto input_scalar_type = input.scalar_type();
if (interp_size == 1 && input_scalar_type == at::ScalarType::Byte) {
// nearest also supports uint8 tensor, but we have to use float
// with compute_indices_weights
input_scalar_type = at::ScalarType::Float;
}
AT_DISPATCH_FLOATING_TYPES_AND(
at::ScalarType::Byte,
input_scalar_type,
"compute_indices_weights_generic",
[&] {
indices_weights.emplace_back(
F<index_t, scalar_t>::compute_indices_weights(
input.size(interp_dim),
oshape[interp_dim],
input.stride(interp_dim) * input.element_size(),
input.dim(),
interp_dim,
align_corners,
scales[interp_dim - 2],
antialias,
interp_size));
});
TensorIteratorConfig config;
config.check_all_same_dtype(false)
.declare_static_dtype_and_device(input.scalar_type(), input.device())
.add_output(output)
.add_input(restrided_input);
for (auto& idx_weight : indices_weights) {
for (auto& tensor : idx_weight) {
config.add_input(tensor);
}
}
auto iter = config.build();
if (interp_size > 1) {
// Nearest also supports uint8 tensor, so need to handle it separately
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "upsample_generic_Nd", [&] {
ti_cpu_upsample_generic_aa<scalar_t, index_t, out_ndims>(
iter, interp_size);
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND(
at::ScalarType::Byte, iter.dtype(), "upsample_generic_Nd", [&] {
ti_cpu_upsample_generic_aa<scalar_t, index_t, out_ndims>(
iter, interp_size);
});
}
}
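// Separable driver below: performs one single-dim antialiased pass per spatial
// dimension, starting from the innermost one (for NCHW input: width with
// interp_dim == 3, then height with interp_dim == 2), storing intermediate
// results in temporary tensors.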
template <
typename index_t,
int out_ndims,
typename scale_type,
template <typename, typename>
class F>
void ti_separable_upsample_generic_Nd_kernel_impl(
Tensor& output,
const Tensor& input,
bool align_corners,
const scale_type& scales,
bool antialias) {
auto temp_oshape = input.sizes().vec();
at::Tensor temp_output, temp_input = input;
for (int i = 0; i < out_ndims - 1; i++) {
int interp_dim = 2 + out_ndims - 1 - i;
temp_oshape[interp_dim] = output.sizes()[interp_dim];
temp_output = at::empty(temp_oshape, input.options());
_ti_separable_upsample_generic_Nd_kernel_impl_single_dim<
index_t,
out_ndims,
scale_t,
HelperInterpLinear>(
temp_output, temp_input, interp_dim, align_corners, scales, antialias);
temp_input = temp_output;
}
_ti_separable_upsample_generic_Nd_kernel_impl_single_dim<
index_t,
out_ndims,
scale_t,
HelperInterpLinear>(
output, temp_input, 2, align_corners, scales, antialias);
}
void _ti_upsample_bilinear2d_kernel_impl(
Tensor& output,
const Tensor& input,
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
bool antialias) {
ti_separable_upsample_generic_Nd_kernel_impl<
int64_t,
2,
scale_t,
HelperInterpLinear>(
output, input, align_corners, {scales_h, scales_w}, antialias);
}
} // namespace internal_upsample
} // namespace native
} // namespace at
namespace vision {
namespace ops {
namespace {
at::Tensor interpolate_linear_aa_forward_kernel(
const at::Tensor& input,
at::IntArrayRef output_size,
bool align_corners) {
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
c10::optional<c10::ArrayRef<double>> scale_factors = {};
// Copied from UpSampleBilinear2d.cpp
auto output = at::empty({0}, input.options());
auto osize = at::native::upsample::compute_output_size(
input.sizes(), output_size, scale_factors);
auto scale_h = at::native::upsample::get_scale_value(scale_factors, 0);
auto scale_w = at::native::upsample::get_scale_value(scale_factors, 1);
auto full_output_size =
at::native::upsample_2d_common_check(input.sizes(), osize);
// Allow for empty batch size but not other dimensions
TORCH_CHECK(
input.numel() != 0 ||
c10::multiply_integers(
input.sizes().begin() + 1, input.sizes().end()),
"Non-empty 4D data tensor expected but got a tensor with sizes ",
input.sizes());
output.resize_(full_output_size, input.suggest_memory_format());
at::native::internal_upsample::_ti_upsample_bilinear2d_kernel_impl(
output, input, align_corners, scale_h, scale_w, /*antialias=*/true);
return output;
}
// TODO: Implement backward function
// at::Tensor interpolate_linear_aa_backward_kernel(
// const at::Tensor& grad) {
// return grad_input;
// }
} // namespace
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_linear_aa"),
TORCH_FN(interpolate_linear_aa_forward_kernel));
// TODO: Implement backward function
// m.impl(
// TORCH_SELECTIVE_NAME("torchvision::_interpolate_linear_aa_backward"),
// TORCH_FN(interpolate_linear_aa_backward_kernel));
}
} // namespace ops
} // namespace vision
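
The heart of the kernel above is _compute_indices_weights_aa, which reproduces PIL's bounds/weights precomputation. As a reading aid, here is a hedged Python sketch of the same one-dimensional computation for the linear (triangle) filter, assuming align_corners=False and no explicit scale factor; function names are illustrative, and the zero-padding of each weight row to a fixed interp_size is omitted:

import math

def triangle_filter(x):
    # Linear (triangle) filter, as in PIL's Resample.c.
    x = abs(x)
    return 1.0 - x if x < 1.0 else 0.0

def compute_aa_indices_weights(in_size, out_size):
    # For the linear filter, interp_size == 2, so interp_size / 2 == 1.
    scale = in_size / out_size
    support = scale if scale >= 1.0 else 1.0
    invscale = 1.0 / scale if scale >= 1.0 else 1.0
    bounds, weights = [], []
    for i in range(out_size):
        center = scale * (i + 0.5)
        xmin = max(int(center - support + 0.5), 0)
        ksize = min(int(center + support + 0.5), in_size) - xmin
        ws = [triangle_filter((j + xmin - center + 0.5) * invscale) for j in range(ksize)]
        total = sum(ws)
        if total != 0.0:
            ws = [w / total for w in ws]
        bounds.append((xmin, ksize))
        weights.append(ws)
    return bounds, weights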
#include "interpolate_aa.h"
#include <torch/types.h>
namespace vision {
namespace ops {
at::Tensor interpolate_linear_aa(
const at::Tensor& input, // Input image
at::IntArrayRef output_size, // Output image size
bool align_corners) // The flag to align corners
{
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::interpolate_linear_aa", "")
.typed<decltype(interpolate_linear_aa)>();
return op.call(input, output_size, align_corners);
}
namespace detail {
// TODO: Implement backward function
// at::Tensor _interpolate_linear_aa_backward(
// const at::Tensor& grad,
// at::IntArrayRef output_size,
// bool align_corners)
// {
// return at::Tensor();
// }
} // namespace detail
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_interpolate_linear_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
// TODO: Implement backward function
// m.def(TORCH_SELECTIVE_SCHEMA(
// "torchvision::_interpolate_linear_aa_backward(Tensor grad, Tensor rois,
// float spatial_scale, int pooled_height, int pooled_width, int
// batch_size, int channels, int height, int width, int sampling_ratio,
// bool aligned) -> Tensor"));
}
} // namespace ops
} // namespace vision
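
With the schema above registered and the CPU kernel from interpolate_aa_kernels.cpp linked in, the op is reachable from Python through the dispatcher; a minimal sketch (forward-only, and CPU-only per the TORCH_CHECK in the kernel):

import torch
import torchvision  # noqa: F401  -- importing torchvision registers the custom ops

x = torch.rand(1, 3, 32, 32)  # must be a CPU tensor
y = torch.ops.torchvision._interpolate_linear_aa(x, [16, 16], False)
print(y.shape)  # torch.Size([1, 3, 16, 16])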
#pragma once
#include <ATen/ATen.h>
#include "../macros.h"
namespace vision {
namespace ops {
VISION_API at::Tensor _interpolate_linear_aa(
const at::Tensor& input,
at::IntArrayRef output_size,
bool align_corners = false);
namespace detail {
// TODO: Implement backward function
// at::Tensor _interpolate_linear_aa_backward(
// const at::Tensor& grad,
// at::IntArrayRef output_size,
// bool align_corners=false);
} // namespace detail
} // namespace ops
} // namespace vision
@@ -341,7 +341,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR,
           max_size: Optional[int] = None, antialias: Optional[bool] = None) -> Tensor:
    r"""Resize the input image to the given size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
@@ -375,6 +375,12 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
            smaller edge may be shorter than ``size``. This is only supported
            if ``size`` is an int (or a sequence of length 1 in torchscript
            mode).
        antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and
            anti-aliasing is always applied. If ``img`` is Tensor, the flag is False by default and can be
            set to True for ``InterpolationMode.BILINEAR`` mode only.

            .. warning::
                There is no autodiff support for the ``antialias=True`` option when the input ``img`` is a Tensor.

    Returns:
        PIL Image or Tensor: Resized image.
@@ -391,10 +397,14 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
        raise TypeError("Argument interpolation should be a InterpolationMode")

    if not isinstance(img, torch.Tensor):
        if antialias is not None and not antialias:
            warnings.warn(
                "Anti-alias option is always applied for PIL Image input. Argument antialias is ignored."
            )
        pil_interpolation = pil_modes_mapping[interpolation]
        return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size)

    return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias)


def scale(*args, **kwargs):
...
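
The dispatch added here can be summarized with a short, hedged sketch: the tensor path forwards the flag, while the PIL path always antialiases and merely warns on antialias=False:

import warnings
import torch
from PIL import Image
from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode

t = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)
out_t = F.resize(t, [64, 64], InterpolationMode.BILINEAR, antialias=True)  # tensor path

pil_img = Image.new("RGB", (256, 256))
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    F.resize(pil_img, [64, 64], InterpolationMode.BILINEAR, antialias=False)
assert "Anti-alias option is always applied" in str(caught[-1].message)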
@@ -470,7 +470,13 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
    return img


def resize(
    img: Tensor,
    size: List[int],
    interpolation: str = "bilinear",
    max_size: Optional[int] = None,
    antialias: Optional[bool] = None
) -> Tensor:
    _assert_image_tensor(img)

    if not isinstance(size, (int, tuple, list)):
@@ -494,6 +500,12 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear", max_si
            "i.e. size should be an int or a sequence of length 1 in torchscript mode."
        )

    if antialias is None:
        antialias = False

    if antialias and interpolation not in ["bilinear", ]:
        raise ValueError("Antialias option is supported for bilinear interpolation mode only")

    w, h = _get_image_size(img)

    if isinstance(size, int) or len(size) == 1:  # specified size only for the smallest edge
@@ -524,6 +536,10 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear", max_si
    # Define align_corners to avoid warnings
    align_corners = False if interpolation in ["bilinear", "bicubic"] else None

    if antialias:
        # Apply antialiasing for downsampling on both dims
        img = torch.ops.torchvision._interpolate_linear_aa(img, [new_h, new_w], align_corners=False)
    else:
        img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)

    if interpolation == "bicubic" and out_dtype == torch.uint8:
...
@@ -257,10 +257,16 @@ class Resize(torch.nn.Module):
            smaller edge may be shorter than ``size``. This is only supported
            if ``size`` is an int (or a sequence of length 1 in torchscript
            mode).
        antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and
            anti-aliasing is always applied. If ``img`` is Tensor, the flag is False by default and can be
            set to True for ``InterpolationMode.BILINEAR`` mode only.

            .. warning::
                There is no autodiff support for the ``antialias=True`` option when the input ``img`` is a Tensor.
    """

    def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=None):
        super().__init__()
        if not isinstance(size, (int, Sequence)):
            raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
@@ -278,6 +284,7 @@ class Resize(torch.nn.Module):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img):
        """
@@ -287,12 +294,12 @@ class Resize(torch.nn.Module):
        Returns:
            PIL Image or Tensor: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)

    def __repr__(self):
        interpolate_str = self.interpolation.value
        return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2}, antialias={3})'.format(
            self.size, interpolate_str, self.max_size, self.antialias)


class Scale(Resize):
...
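
Finally, the module-level form: the new constructor argument is stored on the transform, forwarded to F.resize, and surfaced in repr(); values here are illustrative:

from torchvision import transforms
from torchvision.transforms import InterpolationMode

resize = transforms.Resize((96, 72), interpolation=InterpolationMode.BILINEAR, antialias=True)
print(resize)
# Resize(size=(96, 72), interpolation=bilinear, max_size=None, antialias=True)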