Unverified commit e2dbadbf, authored by vfdev and committed by GitHub

[1/2] Added backward pass on CPU for interpolation with anti-alias option (#4208)



* WIP on backward op interpolation with AA

* Removed cuda tests and reformat cpp code

* Fixed clang wrong formatting

* Added channels last test case
Co-authored-by: vfdev-5 <vfdev-5@gmail.com>
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent 30fd10bd
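The diff below wires up two new private backward ops next to the existing anti-aliased forward ops. As a quick orientation (an illustrative sketch, not part of this diff), the forward op maps an NCHW tensor to the requested spatial size and the backward op scatters an output-shaped gradient back to the input shape; the snippet assumes a torchvision build that includes this change:

```python
import torch
import torchvision  # noqa: F401  -- importing torchvision registers the torchvision:: custom ops

x = torch.rand(1, 3, 32, 29, dtype=torch.float64)
out_size = [10, 7]

# anti-aliased bilinear forward (align_corners=False)
y = torch.ops.torchvision._interpolate_bilinear2d_aa(x, out_size, False)
assert y.shape == (1, 3, 10, 7)

# backward: takes an output-shaped gradient, the 2d output size and the full input size
grad_out = torch.ones_like(y)
grad_in = torch.ops.torchvision._interpolate_bilinear2d_aa_backward(
    grad_out, out_size, list(x.shape), False
)
assert grad_in.shape == x.shape
```

The new test below performs the same pairing through a `torch.autograd.Function` so that `gradcheck` can compare the analytical backward against numerical gradients.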
from functools import partial
import itertools
import os
import colorsys

@@ -578,6 +579,52 @@ def test_assert_resize_antialias(interpolation):
        F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)

@pytest.mark.parametrize('dt', [torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize('size', [[10, 7], [10, 42], [42, 7]])
@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
def test_interpolate_antialias_backward(dt, size, interpolation):

    # temporarily hard-code device as CPU, CUDA support will be done later
    device = "cpu"

    if dt == torch.float16 and device == "cpu":
        # skip float16 on CPU case
        return

    torch.manual_seed(12)
    if interpolation == BILINEAR:
        forward_op = torch.ops.torchvision._interpolate_bilinear2d_aa
        backward_op = torch.ops.torchvision._interpolate_bilinear2d_aa_backward
    elif interpolation == BICUBIC:
        forward_op = torch.ops.torchvision._interpolate_bicubic2d_aa
        backward_op = torch.ops.torchvision._interpolate_bicubic2d_aa_backward

    class F(torch.autograd.Function):

        @staticmethod
        def forward(ctx, i):
            result = forward_op(i, size, False)
            ctx.save_for_backward(i, result)
            return result

        @staticmethod
        def backward(ctx, grad_output):
            i, result = ctx.saved_tensors
            ishape = i.shape
            oshape = result.shape[2:]
            return backward_op(grad_output, oshape, ishape, False)

    x = (
        torch.rand(1, 32, 29, 3, dtype=torch.double, device=device).permute(0, 3, 1, 2).requires_grad_(True),
    )
    assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)

    x = (
        torch.rand(1, 3, 32, 29, dtype=torch.double, device=device, requires_grad=True),
    )
    assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)

def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"):
    script_fn = torch.jit.script(fn)
......
#include <ATen/Parallel.h>
#include <ATen/TypeDefault.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/TensorIterator.h>

@@ -141,6 +142,41 @@ void ti_cpu_upsample_generic_aa(

// Helper structs to use with ti_upsample_generic_Nd_kernel_impl
template <typename index_t, typename scalar_t>
struct HelperInterpBase {
  template <typename filter_fn_t>
  static inline void _compute_weights_aa(
      const int64_t i,
      const int64_t input_size,
      const scalar_t scale,
      const scalar_t support,
      scalar_t* wt_ptr,
      const int64_t interp_size,
      filter_fn_t filter_fn,
      int64_t& xmin,
      int64_t& xsize) {
    scalar_t center = scale * (i + 0.5);
    scalar_t total_w = 0.0;
    scalar_t invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
    xmin = std::max(
        static_cast<int64_t>(center - support + 0.5), static_cast<index_t>(0));
    xsize = std::min(static_cast<int64_t>(center + support + 0.5), input_size) -
        xmin;

    int64_t j = 0;
    for (; j < xsize; j++) {
      scalar_t w = filter_fn((j + xmin - center + 0.5) * invscale);
      wt_ptr[j] = w;
      total_w += w;
    }
    for (j = 0; j < xsize; j++) {
      if (total_w != 0.0) {
        wt_ptr[j] /= total_w;
      }
    }
    for (; j < interp_size; j++) {
      wt_ptr[j] = static_cast<scalar_t>(0.0);
    }
  }
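For readers following the kernel, the `_compute_weights_aa` helper above is the same logic that previously lived inline in `_compute_indices_weights_aa` (see the next hunk): it finds the support window `[xmin, xmin + xsize)` around the back-projected center of output pixel `i`, evaluates the filter there, and normalizes the weights. A rough Python mirror of that logic (illustrative only; `filter_fn` stands in for the bilinear or bicubic filter):

```python
def compute_weights_aa(i, input_size, scale, support, interp_size, filter_fn):
    # back-project the center of output pixel i into input coordinates
    center = scale * (i + 0.5)
    invscale = 1.0 / scale if scale >= 1.0 else 1.0
    xmin = max(int(center - support + 0.5), 0)
    xsize = min(int(center + support + 0.5), input_size) - xmin

    # evaluate the filter over the support window
    weights = [filter_fn((j + xmin - center + 0.5) * invscale) for j in range(xsize)]

    # normalize so the weights sum to 1 (skip if the window collapsed)
    total = sum(weights)
    if total != 0.0:
        weights = [w / total for w in weights]

    # pad with zeros up to interp_size, matching the fixed-width weight rows
    weights += [0.0] * max(interp_size - xsize, 0)
    return xmin, xsize, weights
```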

  template <typename filter_fn_t>
  static inline std::vector<Tensor> _compute_indices_weights_aa(
      int64_t input_size,

@@ -187,43 +223,30 @@ struct HelperInterpBase {
          empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
    }

    int64_t* idx_ptr_xmin = output[0].data_ptr<index_t>();
    int64_t* idx_ptr_size = output[1].data_ptr<index_t>();
    int64_t* idx_ptr_stride = output[2].data_ptr<index_t>();
    scalar_t* wt_ptr = output[3].data_ptr<scalar_t>();
    int64_t* wt_idx_ptr = output[4].data_ptr<index_t>();

    int64_t xmin, xmax;

    for (int64_t i = 0; i < output_size; i++) {
      HelperInterpBase<index_t, scalar_t>::_compute_weights_aa(
          i,
          input_size,
          scale,
          support,
          wt_ptr + i * interp_size,
          interp_size,
          filter_fn,
          xmin,
          xmax);

      idx_ptr_xmin[i] = xmin * stride;
      idx_ptr_size[i] = xmax;
      idx_ptr_stride[i] = stride;
      wt_idx_ptr[i] = i * interp_size * sizeof(scalar_t);
    }
    return output;
  }

@@ -475,6 +498,151 @@ void _ti_upsample_bicubic2d_kernel_impl(
      output, input, align_corners, {scales_h, scales_w}, antialias);
}

template <
    typename scalar_t,
    typename scale_type,
    template <typename, typename>
    class F>
void cpu_upsample_genNd_backward_aa(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    bool align_corners,
    const scale_type& scales) {
  TORCH_CHECK(
      grad_input_.dtype() == grad_output_.dtype(),
      "expected dtype ",
      grad_output_.dtype(),
      " for `grad_input` but got dtype ",
      grad_input_.dtype());

  auto grad_output = grad_output_.contiguous();
  auto grad_input = grad_input_.contiguous();

  auto grad_output_data = grad_output.data_ptr<scalar_t>();
  auto grad_input_data = grad_input.data_ptr<scalar_t>();
  auto input_sizes = grad_input.sizes().vec();
  auto output_sizes = grad_output.sizes().vec();
  auto ndim = input_sizes.size();

  // treat nbatch and channels as one dimension
  int64_t channels = input_sizes[0] * input_sizes[1];
  int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
  int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
  int64_t input_height = (ndim >= 4) ? input_sizes[ndim - 2] : 1;
  int64_t output_height = (ndim >= 4) ? output_sizes[ndim - 2] : 1;
  int64_t input_width = input_sizes[ndim - 1];
  int64_t output_width = output_sizes[ndim - 1];

  int64_t output_slice_size = output_depth * output_height * output_width;
  int interp_size = F<int64_t, float>::interp_size;

  auto loop2d = [&](int64_t begin, int64_t end) {
    const scalar_t height_scale = area_pixel_compute_scale<scalar_t>(
        input_height, output_height, align_corners, scales[0]);
    const scalar_t width_scale = area_pixel_compute_scale<scalar_t>(
        input_width, output_width, align_corners, scales[1]);

    auto input_indexr = [=](int64_t c, int64_t h, int64_t w) {
      return grad_input_data + c * input_height * input_width +
          h * input_width + w;
    };

    const scalar_t support_h = (height_scale >= 1.0)
        ? (interp_size * 0.5) * height_scale
        : interp_size * 0.5;
    const scalar_t support_w = (width_scale >= 1.0)
        ? (interp_size * 0.5) * width_scale
        : interp_size * 0.5;

    const int interp_height = (int)ceilf(support_h) * 2 + 1;
    const int interp_width = (int)ceilf(support_w) * 2 + 1;

    std::vector<scalar_t> wx(interp_width, 0.0);
    std::vector<scalar_t> wy(interp_height, 0.0);

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int64_t xmin, ymin;
    int64_t xsize, ysize;

    auto filter_fn = F<int64_t, scalar_t>::_filter;

    for (int64_t oh = 0; oh < output_height; oh++) {
      F<int64_t, scalar_t>::_compute_weights_aa(
          oh,
          input_height,
          height_scale,
          support_h,
          wy.data(),
          interp_height,
          filter_fn,
          ymin,
          ysize);

      for (int64_t ow = 0; ow < output_width; ow++) {
        F<int64_t, scalar_t>::_compute_weights_aa(
            ow,
            input_width,
            width_scale,
            support_w,
            wx.data(),
            interp_width,
            filter_fn,
            xmin,
            xsize);

        for (int64_t c = begin; c < end; c++) {
          scalar_t grad_output_value =
              grad_output_data[c * output_slice_size + oh * output_width + ow];

          for (size_t y = 0; y < ysize; y++) {
            for (size_t x = 0; x < xsize; x++) {
              *input_indexr(c, ymin + y, xmin + x) +=
                  wx[x] * wy[y] * grad_output_value;
            }
          }
        }
      }
    }
  };

  if (ndim == 4) {
    // upsample bilinear 2d
    at::parallel_for(
        0, channels, at::internal::GRAIN_SIZE / output_slice_size / 4, loop2d);
  } else {
    TORCH_CHECK(false, "Unsupported tensor ndim");
  }

  if (!grad_input_.is_contiguous()) {
    grad_input_.copy_(grad_input);
  }
}

void _upsample_bilinear2d_aa_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    bool align_corners,
    c10::optional<double> scales_h,
    c10::optional<double> scales_w) {
  AT_DISPATCH_FLOATING_TYPES(
      grad_output.scalar_type(), "upsample_bilinear2d_backward_cpu", [&] {
        cpu_upsample_genNd_backward_aa<scalar_t, scale_t, HelperInterpLinear>(
            grad_input, grad_output, align_corners, {scales_h, scales_w});
      });
}

void _upsample_bicubic2d_aa_backward_kernel_impl(
    const Tensor& grad_input,
    const Tensor& grad_output,
    bool align_corners,
    c10::optional<double> scales_h,
    c10::optional<double> scales_w) {
  AT_DISPATCH_FLOATING_TYPES(
      grad_output.scalar_type(), "upsample_bicubic2d_backward_cpu", [&] {
        cpu_upsample_genNd_backward_aa<scalar_t, scale_t, HelperInterpCubic>(
            grad_input, grad_output, align_corners, {scales_h, scales_w});
      });
}
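Conceptually, the backward kernel above is the transpose of the separable forward interpolation: for every output location it recomputes the same row and column weights and accumulates `wy[y] * wx[x] * grad_output` into the input-shaped gradient over the support window. A compact Python rendering of that scatter (illustrative only, assumes align_corners=False and reuses the hypothetical `compute_weights_aa` helper sketched earlier):

```python
def upsample2d_aa_backward(grad_output, in_h, in_w, filter_fn, interp_size):
    # grad_output: nested list indexed as [c][oh][ow]; returns grad_input as [c][ih][iw]
    channels = len(grad_output)
    out_h, out_w = len(grad_output[0]), len(grad_output[0][0])
    h_scale, w_scale = in_h / out_h, in_w / out_w
    support_h = interp_size * 0.5 * max(h_scale, 1.0)
    support_w = interp_size * 0.5 * max(w_scale, 1.0)

    grad_input = [[[0.0] * in_w for _ in range(in_h)] for _ in range(channels)]
    for oh in range(out_h):
        ymin, ysize, wy = compute_weights_aa(oh, in_h, h_scale, support_h, interp_size, filter_fn)
        for ow in range(out_w):
            xmin, xsize, wx = compute_weights_aa(ow, in_w, w_scale, support_w, interp_size, filter_fn)
            for c in range(channels):
                g = grad_output[c][oh][ow]
                for y in range(ysize):
                    for x in range(xsize):
                        # scatter the output gradient back into its input support window
                        grad_input[c][ymin + y][xmin + x] += wy[y] * wx[x] * g
    return grad_input
```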

} // namespace internal_upsample
} // namespace native
} // namespace at

@@ -484,7 +652,7 @@ namespace ops {

namespace {

at::Tensor interpolate_bilinear2d_aa_forward_kernel(
    const at::Tensor& input,
    at::IntArrayRef output_size,
    bool align_corners) {

@@ -515,7 +683,7 @@ at::Tensor interpolate_linear_aa_forward_kernel(
  return output;
}

at::Tensor interpolate_bicubic2d_aa_forward_kernel(
    const at::Tensor& input,
    at::IntArrayRef output_size,
    bool align_corners) {

@@ -546,26 +714,109 @@ at::Tensor interpolate_bicubic_aa_forward_kernel(
  return output;
}

at::Tensor interpolate_bilinear2d_aa_backward_kernel(
    const at::Tensor& grad_output,
    at::IntArrayRef output_size,
    at::IntArrayRef input_size,
    bool align_corners) {
  c10::optional<c10::ArrayRef<double>> scale_factors = {};

  // Copied from UpSampleBilinear2d.cpp::upsample_bilinear2d_backward
  auto grad_input = at::empty({0}, grad_output.options());
  auto osize = at::native::upsample::compute_output_size(
      input_size, output_size, scale_factors);
  auto scale_h = at::native::upsample::get_scale_value(scale_factors, 0);
  auto scale_w = at::native::upsample::get_scale_value(scale_factors, 1);
  auto full_output_size =
      at::native::upsample_2d_common_check(input_size, osize);

  TORCH_CHECK(
      grad_output.dim() == 4,
      "Expected grad_output to be a tensor of dimension 4 but got: dimension ",
      grad_output.dim());

  for (int i = 0; i < 4; ++i) {
    TORCH_CHECK(
        grad_output.size(i) == full_output_size[i],
        "Expected grad_output to have the same shape as output;",
        " output.size(",
        i,
        ") = ",
        full_output_size[i],
        " but got grad_output.size(",
        i,
        ") = ",
        grad_output.size(i));
  }

  grad_input.resize_(input_size, grad_output.suggest_memory_format());
  grad_input.zero_();

  at::native::internal_upsample::_upsample_bilinear2d_aa_backward_kernel_impl(
      grad_input, grad_output, align_corners, scale_h, scale_w);

  return grad_input;
}

at::Tensor interpolate_bicubic2d_aa_backward_kernel(
    const at::Tensor& grad_output,
    at::IntArrayRef output_size,
    at::IntArrayRef input_size,
    bool align_corners) {
  c10::optional<c10::ArrayRef<double>> scale_factors = {};

  // Copied from UpSampleBicubic2d.cpp::upsample_bicubic2d_backward
  auto grad_input = at::empty({0}, grad_output.options());
  auto osize = at::native::upsample::compute_output_size(
      input_size, output_size, scale_factors);
  auto scale_h = at::native::upsample::get_scale_value(scale_factors, 0);
  auto scale_w = at::native::upsample::get_scale_value(scale_factors, 1);
  auto full_output_size =
      at::native::upsample_2d_common_check(input_size, osize);

  TORCH_CHECK(
      grad_output.dim() == 4,
      "Expected grad_output to be a tensor of dimension 4 but got: dimension ",
      grad_output.dim());

  for (int i = 0; i < 4; ++i) {
    TORCH_CHECK(
        grad_output.size(i) == full_output_size[i],
        "Expected grad_output to have the same shape as output;",
        " output.size(",
        i,
        ") = ",
        full_output_size[i],
        " but got grad_output.size(",
        i,
        ") = ",
        grad_output.size(i));
  }

  grad_input.resize_(input_size, grad_output.suggest_memory_format());
  grad_input.zero_();

  at::native::internal_upsample::_upsample_bicubic2d_aa_backward_kernel_impl(
      grad_input, grad_output, align_corners, scale_h, scale_w);

  return grad_input;
}

} // namespace

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa"),
      TORCH_FN(interpolate_bilinear2d_aa_forward_kernel));
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa"),
      TORCH_FN(interpolate_bicubic2d_aa_forward_kernel));
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa_backward"),
      TORCH_FN(interpolate_bilinear2d_aa_backward_kernel));
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa_backward"),
      TORCH_FN(interpolate_bicubic2d_aa_backward_kernel));
}
} // namespace ops
......
@@ -62,23 +62,22 @@ __device__ __forceinline__ static accscalar_t bicubic_filter(accscalar_t x) {

template <typename scalar_t, typename accscalar_t, typename filter_fn_t>
__device__ __forceinline__ static void _compute_weights(
    const int i,
    const int input_size,
    const accscalar_t scale,
    const accscalar_t support,
    scalar_t* wt_ptr,
    int interp_size,
    filter_fn_t filter_fn,
    int& xmin,
    int& xmax) {
  accscalar_t invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
  accscalar_t center = scale * (i + 0.5);
  xmin = max(static_cast<int>(center - support + 0.5), static_cast<int>(0));
  xmax = min(static_cast<int>(center + support + 0.5), input_size) - xmin;

  accscalar_t total_w = 0.0;
  int j = 0;
  for (j = 0; j < xmax; j++) {
    accscalar_t w = filter_fn((j + xmin - center + 0.5) * invscale);
    wt_ptr[j] = static_cast<scalar_t>(w);

@@ -164,7 +163,7 @@ __global__ void upsample_gen2d_out_frame(
  scalar_t buffer2[256];

  // Compute weights
  int xmin, xsize, ymin, ysize;
  typedef scalar_t (*filter_fn_t)(scalar_t);
  if (interp_size == 2) {
    _compute_weights<scalar_t, accscalar_t, filter_fn_t>(

@@ -213,7 +212,7 @@ __global__ void upsample_gen2d_out_frame(
  for (int n = 0; n < batchsize; n++) {
    for (int c = 0; c < channels; ++c) {
      // interpolate on x-axis for ymin to ymin + ysize
      for (int y = 0; y < ysize; y++) {
        // copy data into the local buffer and use
        // interpolate_aa_single_dim method
        for (int x = 0; x < xsize; x++) {

@@ -372,7 +371,7 @@ at::Tensor interpolate_gen2d_aa_forward_kernel(
  return output;
}

at::Tensor interpolate_bilinear2d_aa_forward_kernel(
    const at::Tensor& input,
    at::IntArrayRef output_size,
    bool align_corners) {

@@ -380,7 +379,7 @@ at::Tensor interpolate_linear_aa_forward_kernel(
      input, output_size, align_corners);
}

at::Tensor interpolate_bicubic2d_aa_forward_kernel(
    const at::Tensor& input,
    at::IntArrayRef output_size,
    bool align_corners) {

@@ -392,11 +391,11 @@ at::Tensor interpolate_bicubic_aa_forward_kernel(
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa"),
      TORCH_FN(interpolate_bilinear2d_aa_forward_kernel));
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa"),
      TORCH_FN(interpolate_bicubic2d_aa_forward_kernel));
}
} // namespace ops
......
@@ -5,54 +5,69 @@

namespace vision {
namespace ops {

at::Tensor _interpolate_bilinear2d_aa(
    const at::Tensor& input, // Input image
    at::IntArrayRef output_size, // Output image size
    bool align_corners) // The flag to align corners
{
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_interpolate_bilinear2d_aa", "")
          .typed<decltype(_interpolate_bilinear2d_aa)>();
  return op.call(input, output_size, align_corners);
}

at::Tensor _interpolate_bicubic2d_aa(
    const at::Tensor& input, // Input image
    at::IntArrayRef output_size, // Output image size
    bool align_corners) // The flag to align corners
{
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_interpolate_bicubic2d_aa", "")
          .typed<decltype(_interpolate_bicubic2d_aa)>();
  return op.call(input, output_size, align_corners);
}

namespace detail {

at::Tensor _interpolate_bilinear2d_aa_backward(
    const at::Tensor& grad_output,
    at::IntArrayRef output_size,
    at::IntArrayRef input_size,
    bool align_corners) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow(
              "torchvision::_interpolate_bilinear2d_aa_backward", "")
          .typed<decltype(_interpolate_bilinear2d_aa_backward)>();
  return op.call(grad_output, output_size, input_size, align_corners);
}

at::Tensor _interpolate_bicubic2d_aa_backward(
    const at::Tensor& grad_output,
    at::IntArrayRef output_size,
    at::IntArrayRef input_size,
    bool align_corners) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow(
              "torchvision::_interpolate_bicubic2d_aa_backward", "")
          .typed<decltype(_interpolate_bicubic2d_aa_backward)>();
  return op.call(grad_output, output_size, input_size, align_corners);
}

} // namespace detail

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torchvision::_interpolate_bilinear2d_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torchvision::_interpolate_bicubic2d_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torchvision::_interpolate_bilinear2d_aa_backward(Tensor input, int[] output_size, int[] input_size, bool align_corners) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "torchvision::_interpolate_bicubic2d_aa_backward(Tensor input, int[] output_size, int[] input_size, bool align_corners) -> Tensor"));
}

} // namespace ops
......
@@ -6,23 +6,29 @@

namespace vision {
namespace ops {

VISION_API at::Tensor _interpolate_bilinear2d_aa(
    const at::Tensor& input,
    at::IntArrayRef output_size,
    bool align_corners = false);

VISION_API at::Tensor _interpolate_bicubic2d_aa(
    const at::Tensor& input,
    at::IntArrayRef output_size,
    bool align_corners = false);

namespace detail {

VISION_API at::Tensor _interpolate_bilinear2d_aa_backward(
    const at::Tensor& grad,
    at::IntArrayRef output_size,
    at::IntArrayRef input_size,
    bool align_corners = false);

VISION_API at::Tensor _interpolate_bicubic2d_aa_backward(
    const at::Tensor& grad,
    at::IntArrayRef output_size,
    at::IntArrayRef input_size,
    bool align_corners = false);

} // namespace detail
......
@@ -545,9 +545,9 @@ def resize(

    if antialias:
        if interpolation == "bilinear":
            img = torch.ops.torchvision._interpolate_bilinear2d_aa(img, [new_h, new_w], align_corners=False)
        elif interpolation == "bicubic":
            img = torch.ops.torchvision._interpolate_bicubic2d_aa(img, [new_h, new_w], align_corners=False)
    else:
        img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)
......