Unverified Commit b56f17ae authored by vfdev, committed by GitHub

Added antialias option to transforms.functional.resize (#3761)

* WIP Added antialias option to transforms.functional.resize

* Updates according to the review

* Excluded these C++ files from the iOS build

* Added support for mixed downsampling/upsampling

* Fixed heap overflow caused by explicit loop unrolling

* Applied PR review suggestions
- used pytest parametrize instead of unittest
- cast to scalar_t ptr
- removed interpolate_aa files for iOS/Android, keeping the original CMake version
parent d6fee5a4
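For context, a minimal usage sketch of the new flag (sizes and values are illustrative, not taken from this PR):

import torch
from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode

img = torch.randint(0, 256, (3, 320, 290), dtype=torch.uint8)  # CHW image tensor

# Tensor path: antialias defaults to False; setting it to True gives
# downsampling results much closer to PIL (bilinear mode only in this PR).
out = F.resize(img, size=[96, 72], interpolation=InterpolationMode.BILINEAR, antialias=True)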
@@ -14,6 +14,13 @@ file(GLOB VISION_SRCS
../../torchvision/csrc/ops/*.h
../../torchvision/csrc/ops/*.cpp)
# Remove interpolate_aa sources as they are temporary code
# see https://github.com/pytorch/vision/pull/3761
# and IndexingUtils.h is unavailable in the Android build
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp")
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/interpolate_aa.cpp")
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/interpolate_aa.h")
add_library(${TARGET} SHARED
${VISION_SRCS}
)
......
@@ -11,6 +11,13 @@ file(GLOB VISION_SRCS
../torchvision/csrc/ops/*.h
../torchvision/csrc/ops/*.cpp)
# Remove interpolate_aa sources as they are temporary code
# see https://github.com/pytorch/vision/pull/3761
# and TensorIterator is unavailable on iOS
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp")
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/interpolate_aa.cpp")
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/interpolate_aa.h")
add_library(${TARGET} STATIC
${VISION_SRCS}
)
......
@@ -1018,5 +1018,52 @@ def test_perspective_interpolation_warning(tester):
tester.assertTrue(res1.equal(res2))
@pytest.mark.parametrize('device', ["cpu", ])
@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize('size', [[96, 72], [96, 420], [420, 72]])
@pytest.mark.parametrize('interpolation', [BILINEAR, ])
def test_resize_antialias(device, dt, size, interpolation, tester):
if dt == torch.float16 and device == "cpu":
# skip float16 on CPU case
return
script_fn = torch.jit.script(F.resize)
tensor, pil_img = tester._create_data(320, 290, device=device)
if dt is not None:
# Trivial cast of uint8 data to float to cover all dtype cases
tensor = tensor.to(dt)
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)
tester.assertEqual(
resized_tensor.size()[1:], resized_pil_img.size[::-1],
msg=f"{size}, {interpolation}, {dt}"
)
resized_tensor_f = resized_tensor
# we need to cast to uint8 to compare with PIL image
if resized_tensor_f.dtype == torch.uint8:
resized_tensor_f = resized_tensor_f.to(torch.float)
tester.approxEqualTensorToPIL(
resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}"
)
tester.approxEqualTensorToPIL(
resized_tensor_f, resized_pil_img, tol=1.0 + 1e-5, agg_method="max",
msg=f"{size}, {interpolation}, {dt}"
)
if isinstance(size, int):
script_size = [size, ]
else:
script_size = size
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, antialias=True)
tester.assertTrue(resized_tensor.equal(resize_result), msg=f"{size}, {interpolation}, {dt}")
if __name__ == '__main__':
unittest.main()
@@ -348,6 +348,10 @@ class Tester(unittest.TestCase):
self.assertEqual((owidth, oheight), result.size)
with self.assertWarnsRegex(UserWarning, r"Anti-alias option is always applied for PIL Image input"):
t = transforms.Resize(osize, antialias=False)
t(img)
def test_random_crop(self):
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
......
#include <ATen/TypeDefault.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/UpSample.h>
#include <cmath>
#include <vector>
#include <torch/library.h>
// This code lives temporarily in torchvision before being merged into PyTorch
namespace at {
namespace native {
namespace internal_upsample {
using scale_t = std::vector<c10::optional<double>>;
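// The two helpers below compute one output value along a single dimension as
// the weighted sum of a precomputed input window; the "zero_strides" variant
// assumes the index/weight tensors are broadcast along the iterated dimension.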
template <typename scalar_t, typename index_t>
static inline scalar_t interpolate_aa_single_dim_zero_strides(
char* src,
char** data,
int64_t i,
const index_t ids_stride) {
const index_t ids_min = *(index_t*)&data[0][0];
const index_t ids_size = *(index_t*)&data[1][0];
char* src_min = src + ids_min;
scalar_t t = *(scalar_t*)&src_min[0];
index_t wts_idx = *(index_t*)&data[4][0];
scalar_t* wts_ptr = (scalar_t*)&data[3][wts_idx];
scalar_t wts = wts_ptr[0];
scalar_t output = t * wts;
int j = 1;
for (; j < ids_size; j++) {
wts = wts_ptr[j];
t = *(scalar_t*)&src_min[j * ids_stride];
output += t * wts;
}
return output;
}
template <typename scalar_t, typename index_t>
static inline scalar_t interpolate_aa_single_dim(
char* src,
char** data,
const int64_t* strides,
int64_t i,
const index_t ids_stride) {
index_t ids_min = *(index_t*)&data[0][i * strides[0]];
index_t ids_size = *(index_t*)&data[1][i * strides[1]];
char* src_min = src + ids_min;
scalar_t t = *(scalar_t*)&src_min[0];
index_t wts_idx = *(index_t*)&data[4][i * strides[4]];
scalar_t* wts_ptr = (scalar_t*)&data[3][wts_idx];
scalar_t wts = wts_ptr[0];
scalar_t output = t * wts;
int j = 1;
for (; j < ids_size; j++) {
wts = wts_ptr[j];
t = *(scalar_t*)&src_min[j * ids_stride];
output += t * wts;
}
return output;
}
template <typename scalar_t, typename index_t>
static inline void basic_loop_aa_single_dim_zero_strides(
char** data,
const int64_t* strides,
int64_t n) {
char* dst = data[0];
char* src = data[1];
// index stride is constant for the given dimension
const index_t ids_stride = *(index_t*)&data[2 + 2][0];
for (int64_t i = 0; i < n; i++) {
*(scalar_t*)&dst[i * strides[0]] =
interpolate_aa_single_dim_zero_strides<scalar_t, index_t>(
src + i * strides[1], &data[2], i, ids_stride);
}
}
template <typename scalar_t, typename index_t>
static inline void basic_loop_aa_single_dim_nonzero_strides(
char** data,
const int64_t* strides,
int64_t n) {
char* dst = data[0];
char* src = data[1];
// index stride is constant for the given dimension
const index_t ids_stride = *(index_t*)&data[2 + 2][0];
if (strides[1] == 0) {
for (int64_t i = 0; i < n; i++) {
*(scalar_t*)&dst[i * strides[0]] =
interpolate_aa_single_dim<scalar_t, index_t>(
src, &data[2], &strides[2], i, ids_stride);
}
} else {
for (int64_t i = 0; i < n; i++) {
*(scalar_t*)&dst[i * strides[0]] =
interpolate_aa_single_dim<scalar_t, index_t>(
src + i * strides[1], &data[2], &strides[2], i, ids_stride);
}
}
}
template <int m>
static inline bool is_zero_stride(const int64_t* strides) {
bool output = strides[0] == 0;
for (int i = 1; i < m; i++) {
output &= (strides[i] == 0);
}
return output;
}
template <typename scalar_t, typename index_t, int out_ndims>
void ti_cpu_upsample_generic_aa(
at::TensorIterator& iter,
int interp_size = -1) {
TORCH_INTERNAL_ASSERT(interp_size > 0);
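// Take the fast path when dst and src advance contiguously by one scalar_t
// and all index/weight inputs have zero stride (broadcast) along this
// dimension; otherwise fall back to the generic strided loop.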
auto loop = [&](char** data, const int64_t* strides, int64_t n) {
if ((strides[0] == sizeof(scalar_t)) && (strides[1] == sizeof(scalar_t)) &&
is_zero_stride<3 + 2>(&strides[2])) {
basic_loop_aa_single_dim_zero_strides<scalar_t, index_t>(
data, strides, n);
} else {
basic_loop_aa_single_dim_nonzero_strides<scalar_t, index_t>(
data, strides, n);
}
};
iter.for_each(loop);
}
// Helper structs to use with ti_separable_upsample_generic_Nd_kernel_impl
template <typename index_t, typename scalar_t>
struct HelperInterpBase {
static inline void init_indices_weights(
std::vector<Tensor>& output,
int64_t output_size,
int64_t ndims,
int64_t reshape_dim,
int interp_size) {
auto new_shape = std::vector<int64_t>(ndims, 1);
new_shape[reshape_dim] = output_size;
for (int j = 0; j < interp_size; j++) {
output.emplace_back(
empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
output.emplace_back(
empty(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>())));
}
}
};
template <typename index_t, typename scalar_t>
struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {
static const int interp_size = 2;
static inline std::vector<Tensor> compute_indices_weights(
int64_t input_size,
int64_t output_size,
int64_t stride,
int64_t ndims,
int64_t reshape_dim,
bool align_corners,
const c10::optional<double> opt_scale,
bool antialias,
int& out_interp_size) {
scalar_t scale = area_pixel_compute_scale<scalar_t>(
input_size, output_size, align_corners, opt_scale);
TORCH_INTERNAL_ASSERT(antialias);
return _compute_indices_weights_aa(
input_size,
output_size,
stride,
ndims,
reshape_dim,
align_corners,
scale,
out_interp_size);
}
// taken from
// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
// src/libImaging/Resample.c#L20-L29
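// Triangle (linear) filter with support 1: f(x) = 1 - |x| for |x| < 1, else 0.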
static inline scalar_t _filter(scalar_t x) {
if (x < 0.0) {
x = -x;
}
if (x < 1.0) {
return 1.0 - x;
}
return 0.0;
}
static inline std::vector<Tensor> _compute_indices_weights_aa(
int64_t input_size,
int64_t output_size,
int64_t stride,
int64_t ndims,
int64_t reshape_dim,
bool align_corners,
scalar_t scale,
int& out_interp_size) {
int interp_size = HelperInterpLinear<index_t, scalar_t>::interp_size;
scalar_t support =
(scale >= 1.0) ? (interp_size / 2) * scale : interp_size / 2 * 1.0;
interp_size = (int)ceilf(support) * 2 + 1;
// pass the computed interp_size back to the caller
out_interp_size = interp_size;
std::vector<Tensor> output;
auto new_shape = std::vector<int64_t>(ndims, 1);
new_shape[reshape_dim] = output_size;
// ---- Bounds approach as in PIL -----
// bounds: xmin/xmax
output.emplace_back(
empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
output.emplace_back(
empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
output.emplace_back(
empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
{
// Weights
new_shape[reshape_dim] = output_size * interp_size;
auto wts = empty(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>()));
auto strides = wts.strides().vec();
strides[reshape_dim] = 0;
new_shape[reshape_dim] = output_size;
wts = wts.as_strided(new_shape, strides);
output.emplace_back(wts);
// Weights indices
output.emplace_back(
empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
}
scalar_t center, total_w, invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
index_t zero = static_cast<index_t>(0);
index_t* idx_ptr_xmin = output[0].data_ptr<index_t>();
index_t* idx_ptr_size = output[1].data_ptr<index_t>();
index_t* idx_ptr_stride = output[2].data_ptr<index_t>();
scalar_t* wt_ptr = output[3].data_ptr<scalar_t>();
index_t* wt_idx_ptr = output[4].data_ptr<index_t>();
int64_t xmin, xmax, j;
for (int64_t i = 0; i < output_size; i++) {
center = scale * (i + 0.5);
xmin = std::max(static_cast<int64_t>(center - support + 0.5), zero);
xmax =
std::min(static_cast<int64_t>(center + support + 0.5), input_size) -
xmin;
idx_ptr_xmin[i] = xmin * stride;
idx_ptr_size[i] = xmax;
idx_ptr_stride[i] = stride;
wt_idx_ptr[i] = i * interp_size * sizeof(scalar_t);
total_w = 0.0;
for (j = 0; j < xmax; j++) {
scalar_t w = _filter((j + xmin - center + 0.5) * invscale);
wt_ptr[i * interp_size + j] = w;
total_w += w;
}
for (j = 0; j < xmax; j++) {
if (total_w != 0.0) {
wt_ptr[i * interp_size + j] /= total_w;
}
}
for (; j < interp_size; j++) {
wt_ptr[i * interp_size + j] = static_cast<scalar_t>(0.0);
}
}
return output;
}
};
template <
typename index_t,
int out_ndims,
typename scale_type,
template <typename, typename>
class F>
void _ti_separable_upsample_generic_Nd_kernel_impl_single_dim(
Tensor& output,
const Tensor& input,
int interp_dim,
bool align_corners,
const scale_type& scales,
bool antialias) {
// input can be NCHW, NCL or NCKHW
auto shape = input.sizes().vec();
auto strides = input.strides().vec();
auto oshape = output.sizes();
TORCH_INTERNAL_ASSERT(
shape.size() == oshape.size() && shape.size() == 2 + out_ndims);
TORCH_INTERNAL_ASSERT(strides.size() == 2 + out_ndims);
TORCH_INTERNAL_ASSERT(antialias);
for (int i = 0; i < out_ndims; i++) {
shape[i + 2] = oshape[i + 2];
}
strides[interp_dim] = 0;
auto restrided_input = input.as_strided(shape, strides);
std::vector<std::vector<Tensor>> indices_weights;
int interp_size = F<index_t, float>::interp_size;
auto input_scalar_type = input.scalar_type();
if (interp_size == 1 && input_scalar_type == at::ScalarType::Byte) {
// nearest also supports uint8 tensor, but we have to use float
// with compute_indices_weights
input_scalar_type = at::ScalarType::Float;
}
AT_DISPATCH_FLOATING_TYPES_AND(
at::ScalarType::Byte,
input_scalar_type,
"compute_indices_weights_generic",
[&] {
indices_weights.emplace_back(
F<index_t, scalar_t>::compute_indices_weights(
input.size(interp_dim),
oshape[interp_dim],
input.stride(interp_dim) * input.element_size(),
input.dim(),
interp_dim,
align_corners,
scales[interp_dim - 2],
antialias,
interp_size));
});
TensorIteratorConfig config;
config.check_all_same_dtype(false)
.declare_static_dtype_and_device(input.scalar_type(), input.device())
.add_output(output)
.add_input(restrided_input);
for (auto& idx_weight : indices_weights) {
for (auto& tensor : idx_weight) {
config.add_input(tensor);
}
}
auto iter = config.build();
if (interp_size > 1) {
// Nearest also supports uint8 tensor, so need to handle it separately
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "upsample_generic_Nd", [&] {
ti_cpu_upsample_generic_aa<scalar_t, index_t, out_ndims>(
iter, interp_size);
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND(
at::ScalarType::Byte, iter.dtype(), "upsample_generic_Nd", [&] {
ti_cpu_upsample_generic_aa<scalar_t, index_t, out_ndims>(
iter, interp_size);
});
}
}
template <
typename index_t,
int out_ndims,
typename scale_type,
template <typename, typename>
class F>
void ti_separable_upsample_generic_Nd_kernel_impl(
Tensor& output,
const Tensor& input,
bool align_corners,
const scale_type& scales,
bool antialias) {
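// Separable implementation: interpolate one spatial dimension at a time into
// a temporary tensor, starting from the last dimension; the remaining
// dimension is interpolated directly into the final output below.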
auto temp_oshape = input.sizes().vec();
at::Tensor temp_output, temp_input = input;
for (int i = 0; i < out_ndims - 1; i++) {
int interp_dim = 2 + out_ndims - 1 - i;
temp_oshape[interp_dim] = output.sizes()[interp_dim];
temp_output = at::empty(temp_oshape, input.options());
_ti_separable_upsample_generic_Nd_kernel_impl_single_dim<
index_t,
out_ndims,
scale_t,
HelperInterpLinear>(
temp_output, temp_input, interp_dim, align_corners, scales, antialias);
temp_input = temp_output;
}
_ti_separable_upsample_generic_Nd_kernel_impl_single_dim<
index_t,
out_ndims,
scale_t,
HelperInterpLinear>(
output, temp_input, 2, align_corners, scales, antialias);
}
void _ti_upsample_bilinear2d_kernel_impl(
Tensor& output,
const Tensor& input,
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
bool antialias) {
ti_separable_upsample_generic_Nd_kernel_impl<
int64_t,
2,
scale_t,
HelperInterpLinear>(
output, input, align_corners, {scales_h, scales_w}, antialias);
}
} // namespace internal_upsample
} // namespace native
} // namespace at
namespace vision {
namespace ops {
namespace {
at::Tensor interpolate_linear_aa_forward_kernel(
const at::Tensor& input,
at::IntArrayRef output_size,
bool align_corners) {
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
c10::optional<c10::ArrayRef<double>> scale_factors = {};
// Copied from UpSampleBilinear2d.cpp
auto output = at::empty({0}, input.options());
auto osize = at::native::upsample::compute_output_size(
input.sizes(), output_size, scale_factors);
auto scale_h = at::native::upsample::get_scale_value(scale_factors, 0);
auto scale_w = at::native::upsample::get_scale_value(scale_factors, 1);
auto full_output_size =
at::native::upsample_2d_common_check(input.sizes(), osize);
// Allow for empty batch size but not other dimensions
TORCH_CHECK(
input.numel() != 0 ||
c10::multiply_integers(
input.sizes().begin() + 1, input.sizes().end()),
"Non-empty 4D data tensor expected but got a tensor with sizes ",
input.sizes());
output.resize_(full_output_size, input.suggest_memory_format());
at::native::internal_upsample::_ti_upsample_bilinear2d_kernel_impl(
output, input, align_corners, scale_h, scale_w, /*antialias=*/true);
return output;
}
// TODO: Implement backward function
// at::Tensor interpolate_linear_aa_backward_kernel(
// const at::Tensor& grad) {
// return grad_input;
// }
} // namespace
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_linear_aa"),
TORCH_FN(interpolate_linear_aa_forward_kernel));
// TODO: Implement backward function
// m.impl(
// TORCH_SELECTIVE_NAME("torchvision::_interpolate_linear_aa_backward"),
// TORCH_FN(interpolate_linear_aa_backward_kernel));
}
} // namespace ops
} // namespace vision
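For intuition, _compute_indices_weights_aa above follows PIL's resampling scheme. Below is a hedged plain-Python sketch of the same per-dimension bounds-and-weights computation (helper names are illustrative; the C++ version additionally packs the results into broadcastable tensors and byte offsets, and the scale normally comes from area_pixel_compute_scale):

import math

def triangle(x):
    # Linear (bilinear) filter with support 1
    x = abs(x)
    return 1.0 - x if x < 1.0 else 0.0

def compute_indices_weights_aa(in_size, out_size):
    scale = in_size / out_size  # assumes align_corners=False, no explicit scale factor
    support = scale if scale >= 1.0 else 1.0        # interp_size / 2 == 1 for bilinear
    invscale = 1.0 / scale if scale >= 1.0 else 1.0
    interp_size = int(math.ceil(support)) * 2 + 1
    bounds, weights = [], []
    for i in range(out_size):
        center = scale * (i + 0.5)
        xmin = max(int(center - support + 0.5), 0)
        xsize = min(int(center + support + 0.5), in_size) - xmin
        ws = [triangle((j + xmin - center + 0.5) * invscale) for j in range(xsize)]
        total = sum(ws)
        ws = [w / total if total != 0.0 else w for w in ws]
        ws += [0.0] * (interp_size - xsize)         # zero padding, as in the kernel
        bounds.append((xmin, xsize))
        weights.append(ws)
    return bounds, weights

# Output pixel i is then sum(weights[i][j] * src[bounds[i][0] + j] for j in range(bounds[i][1])).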
#include "interpolate_aa.h"
#include <torch/types.h>
namespace vision {
namespace ops {
at::Tensor _interpolate_linear_aa(
const at::Tensor& input, // Input image
at::IntArrayRef output_size, // Output image size
bool align_corners) // The flag to align corners
{
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_interpolate_linear_aa", "")
.typed<decltype(_interpolate_linear_aa)>();
return op.call(input, output_size, align_corners);
}
namespace detail {
// TODO: Implement backward function
// at::Tensor _interpolate_linear_aa_backward(
// const at::Tensor& grad,
// at::IntArrayRef output_size,
// bool align_corners)
// {
// return at::Tensor();
// }
} // namespace detail
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_interpolate_linear_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
// TODO: Implement backward function
// m.def(TORCH_SELECTIVE_SCHEMA(
// "torchvision::_interpolate_linear_aa_backward(Tensor grad, Tensor rois,
// float spatial_scale, int pooled_height, int pooled_width, int
// batch_size, int channels, int height, int width, int sampling_ratio,
// bool aligned) -> Tensor"));
}
} // namespace ops
} // namespace vision
#pragma once
#include <ATen/ATen.h>
#include "../macros.h"
namespace vision {
namespace ops {
VISION_API at::Tensor _interpolate_linear_aa(
const at::Tensor& input,
at::IntArrayRef output_size,
bool align_corners = false);
namespace detail {
// TODO: Implement backward function
// at::Tensor _interpolate_linear_aa_backward(
// const at::Tensor& grad,
// at::IntArrayRef output_size,
// bool align_corners=false);
} // namespace detail
} // namespace ops
} // namespace vision
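With the extension built, the operator registered above can be invoked from Python through the dispatcher. A minimal sketch, assuming a float CPU tensor (this PR registers only a CPU forward kernel):

import torch
import torchvision  # importing torchvision loads the C++ extension that registers the op

x = torch.rand(1, 3, 320, 290)  # NCHW float tensor, sizes illustrative
y = torch.ops.torchvision._interpolate_linear_aa(x, [96, 72], False)
print(y.shape)  # torch.Size([1, 3, 96, 72])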
@@ -341,7 +341,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None) -> Tensor:
max_size: Optional[int] = None, antialias: Optional[bool] = None) -> Tensor:
r"""Resize the input image to the given size.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
@@ -375,6 +375,12 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
smaller edge may be shorter than ``size``. This is only supported
if ``size`` is an int (or a sequence of length 1 in torchscript
mode).
antialias (bool, optional): antialias flag. If ``img`` is a PIL Image, the flag is ignored and
anti-aliasing is always applied. If ``img`` is a Tensor, the flag defaults to ``False`` and can
be set to ``True`` for ``InterpolationMode.BILINEAR`` mode only.
.. warning::
There is no autodiff support for ``antialias=True`` option with input ``img`` as Tensor.
Returns:
PIL Image or Tensor: Resized image.
@@ -391,10 +397,14 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
raise TypeError("Argument interpolation should be an InterpolationMode")
if not isinstance(img, torch.Tensor):
if antialias is not None and not antialias:
warnings.warn(
"Anti-alias option is always applied for PIL Image input. Argument antialias is ignored."
)
pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size)
return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size)
return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias)
def scale(*args, **kwargs):
......
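A short sketch of the documented behavior above, with illustrative sizes (the warning and error texts come from this diff):

import warnings
import torch
from PIL import Image
from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode

# PIL input: antialias=False is ignored and a UserWarning is emitted.
pil_img = Image.new("RGB", (290, 320))
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    F.resize(pil_img, size=[96, 72], antialias=False)
assert any("Anti-alias option is always applied" in str(w.message) for w in caught)

# Tensor input: antialias=True is rejected for non-bilinear modes.
tensor = torch.rand(3, 320, 290)
try:
    F.resize(tensor, size=[96, 72], interpolation=InterpolationMode.NEAREST, antialias=True)
except ValueError as e:
    print(e)  # Antialias option is supported for bilinear interpolation mode only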
@@ -470,7 +470,13 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
return img
def resize(img: Tensor, size: List[int], interpolation: str = "bilinear", max_size: Optional[int] = None) -> Tensor:
def resize(
img: Tensor,
size: List[int],
interpolation: str = "bilinear",
max_size: Optional[int] = None,
antialias: Optional[bool] = None
) -> Tensor:
_assert_image_tensor(img)
if not isinstance(size, (int, tuple, list)):
@@ -494,6 +500,12 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear", max_si
"i.e. size should be an int or a sequence of length 1 in torchscript mode."
)
if antialias is None:
antialias = False
if antialias and interpolation not in ["bilinear", ]:
raise ValueError("Antialias option is supported for bilinear interpolation mode only")
w, h = _get_image_size(img)
if isinstance(size, int) or len(size) == 1: # specified size only for the smallest edge
@@ -524,7 +536,11 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear", max_si
# Define align_corners to avoid warnings
align_corners = False if interpolation in ["bilinear", "bicubic"] else None
img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)
if antialias:
# Apply antialias for downsampling on both dims
img = torch.ops.torchvision._interpolate_linear_aa(img, [new_h, new_w], align_corners=False)
else:
img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)
if interpolation == "bicubic" and out_dtype == torch.uint8:
img = img.clamp(min=0, max=255)
......
@@ -257,10 +257,16 @@ class Resize(torch.nn.Module):
smaller edge may be shorter than ``size``. This is only supported
if ``size`` is an int (or a sequence of length 1 in torchscript
mode).
antialias (bool, optional): antialias flag. If ``img`` is a PIL Image, the flag is ignored and
anti-aliasing is always applied. If ``img`` is a Tensor, the flag defaults to ``False`` and can
be set to ``True`` for ``InterpolationMode.BILINEAR`` mode only.
.. warning::
There is no autodiff support for ``antialias=True`` option with input ``img`` as Tensor.
"""
def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None):
def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=None):
super().__init__()
if not isinstance(size, (int, Sequence)):
raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
@@ -278,6 +284,7 @@ class Resize(torch.nn.Module):
interpolation = _interpolation_modes_from_int(interpolation)
self.interpolation = interpolation
self.antialias = antialias
def forward(self, img):
"""
@@ -287,12 +294,12 @@ class Resize(torch.nn.Module):
Returns:
PIL Image or Tensor: Rescaled image.
"""
return F.resize(img, self.size, self.interpolation, self.max_size)
return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
def __repr__(self):
interpolate_str = self.interpolation.value
return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2})'.format(
self.size, interpolate_str, self.max_size)
return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2}, antialias={3})'.format(
self.size, interpolate_str, self.max_size, self.antialias)
class Scale(Resize):
......
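The transform-level counterpart, as a minimal usage sketch (values illustrative):

import torch
from torchvision import transforms
from torchvision.transforms import InterpolationMode

t = transforms.Resize([96, 72], interpolation=InterpolationMode.BILINEAR, antialias=True)
img = torch.randint(0, 256, (3, 320, 290), dtype=torch.uint8)
out = t(img)  # anti-aliased bilinear resize on the tensor path
print(t)      # Resize(size=[96, 72], interpolation=bilinear, max_size=None, antialias=True)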