Unverified commit 26fe8fad authored by vfdev, committed by GitHub

Remove custom ops interpolation with antialiasing (#5329)



* Removed custom ops for interp with AA

* Fixed ufmt issues

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 0db67d85
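With the custom ops gone, antialiased tensor resizing in torchvision goes through the regular PyTorch interpolate path (see the functional_tensor.py change at the end of this diff). A minimal usage sketch, illustrative only, with arbitrary shapes and sizes:

import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F

img = torch.rand(1, 3, 256, 256)  # any (N, C, H, W) or (C, H, W) image tensor
# antialias=True now forwards to torch.nn.functional.interpolate(..., antialias=True)
# instead of the removed torchvision::_interpolate_*_aa custom ops
out = F.resize(img, size=[128, 128], interpolation=InterpolationMode.BILINEAR, antialias=True)
print(out.shape)  # torch.Size([1, 3, 128, 128])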
@@ -14,13 +14,6 @@ file(GLOB VISION_SRCS
../../torchvision/csrc/ops/*.h
../../torchvision/csrc/ops/*.cpp)
# Remove interpolate_aa sources as they are temporary code
# see https://github.com/pytorch/vision/pull/3761
# and IndexingUtils.h is unavailable on Android build
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp")
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/interpolate_aa.cpp")
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/interpolate_aa.h")
add_library(${TARGET} SHARED
${VISION_SRCS}
)
@@ -11,13 +11,6 @@ file(GLOB VISION_SRCS
../torchvision/csrc/ops/*.h
../torchvision/csrc/ops/*.cpp)
# Remove interpolate_aa sources as they are temporary code
# see https://github.com/pytorch/vision/pull/3761
# and using TensorIterator unavailable with iOS
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp")
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/interpolate_aa.cpp")
list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/interpolate_aa.h")
add_library(${TARGET} STATIC
${VISION_SRCS}
)
@@ -3,6 +3,7 @@ import itertools
import math
import os
import re
from functools import partial
from typing import Sequence
import numpy as np
@@ -655,11 +656,13 @@ def test_resize_antialias(device, dt, size, interpolation):
def test_assert_resize_antialias(interpolation):
# Checks implementation on very large scales
# and catch TORCH_CHECK inside interpolate_aa_kernels.cu
# and catch TORCH_CHECK inside PyTorch implementation
torch.manual_seed(12)
tensor, pil_img = _create_data(1000, 1000, device="cuda")
tensor, _ = _create_data(1000, 1000, device="cuda")
with pytest.raises(RuntimeError, match=r"Max supported scale factor is"):
# Error message is not yet updated in pytorch nightly
# with pytest.raises(RuntimeError, match=r"Provided interpolation parameters can not be handled"):
with pytest.raises(RuntimeError, match=r"Too much shared memory required"):
F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)
@@ -674,32 +677,12 @@ def test_interpolate_antialias_backward(device, dt, size, interpolation):
return
torch.manual_seed(12)
if interpolation == BILINEAR:
forward_op = torch.ops.torchvision._interpolate_bilinear2d_aa
backward_op = torch.ops.torchvision._interpolate_bilinear2d_aa_backward
elif interpolation == BICUBIC:
forward_op = torch.ops.torchvision._interpolate_bicubic2d_aa
backward_op = torch.ops.torchvision._interpolate_bicubic2d_aa_backward
class F(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
result = forward_op(i, size, False)
ctx.save_for_backward(i, result)
return result
@staticmethod
def backward(ctx, grad_output):
i, result = ctx.saved_tensors
ishape = i.shape
oshape = result.shape[2:]
return backward_op(grad_output, oshape, ishape, False)
x = (torch.rand(1, 32, 29, 3, dtype=torch.double, device=device).permute(0, 3, 1, 2).requires_grad_(True),)
assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
resize = partial(F.resize, size=size, interpolation=interpolation, antialias=True)
assert torch.autograd.gradcheck(resize, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
x = (torch.rand(1, 3, 32, 29, dtype=torch.double, device=device, requires_grad=True),)
assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
assert torch.autograd.gradcheck(resize, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
def check_functional_vs_PIL_vs_scripted(
torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp (file deleted)
#include <ATen/Parallel.h>
#include <ATen/TypeDefault.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/UpSample.h>
#include <cmath>
#include <vector>
#include <torch/library.h>
// This code lives temporarily in torchvision before being merged into PyTorch
namespace at {
namespace native {
namespace internal_upsample {
using scale_t = std::vector<c10::optional<double>>;
template <typename scalar_t, typename index_t>
static inline scalar_t interpolate_aa_single_dim_zero_strides(
char* src,
char** data,
int64_t i,
const index_t ids_stride) {
const index_t ids_min = *(index_t*)&data[0][0];
const index_t ids_size = *(index_t*)&data[1][0];
char* src_min = src + ids_min;
scalar_t t = *(scalar_t*)&src_min[0];
index_t wts_idx = *(index_t*)&data[4][0];
scalar_t* wts_ptr = (scalar_t*)&data[3][wts_idx];
scalar_t wts = wts_ptr[0];
scalar_t output = t * wts;
int j = 1;
for (; j < ids_size; j++) {
wts = wts_ptr[j];
t = *(scalar_t*)&src_min[j * ids_stride];
output += t * wts;
}
return output;
}
template <typename scalar_t, typename index_t>
static inline scalar_t interpolate_aa_single_dim(
char* src,
char** data,
const int64_t* strides,
int64_t i,
const index_t ids_stride) {
index_t ids_min = *(index_t*)&data[0][i * strides[0]];
index_t ids_size = *(index_t*)&data[1][i * strides[1]];
char* src_min = src + ids_min;
scalar_t t = *(scalar_t*)&src_min[0];
index_t wts_idx = *(index_t*)&data[4][i * strides[4]];
scalar_t* wts_ptr = (scalar_t*)&data[3][wts_idx];
scalar_t wts = wts_ptr[0];
scalar_t output = t * wts;
int j = 1;
for (; j < ids_size; j++) {
wts = wts_ptr[j];
t = *(scalar_t*)&src_min[j * ids_stride];
output += t * wts;
}
return output;
}
template <typename scalar_t, typename index_t>
static inline void basic_loop_aa_single_dim_zero_strides(
char** data,
const int64_t* strides,
int64_t n) {
char* dst = data[0];
char* src = data[1];
// index stride is constant for the given dimension
const index_t ids_stride = *(index_t*)&data[2 + 2][0];
for (int64_t i = 0; i < n; i++) {
*(scalar_t*)&dst[i * strides[0]] =
interpolate_aa_single_dim_zero_strides<scalar_t, index_t>(
src + i * strides[1], &data[2], i, ids_stride);
}
}
template <typename scalar_t, typename index_t>
static inline void basic_loop_aa_single_dim_nonzero_strides(
char** data,
const int64_t* strides,
int64_t n) {
char* dst = data[0];
char* src = data[1];
// index stride is constant for the given dimension
const index_t ids_stride = *(index_t*)&data[2 + 2][0];
if (strides[1] == 0) {
for (int64_t i = 0; i < n; i++) {
*(scalar_t*)&dst[i * strides[0]] =
interpolate_aa_single_dim<scalar_t, index_t>(
src, &data[2], &strides[2], i, ids_stride);
}
} else {
for (int64_t i = 0; i < n; i++) {
*(scalar_t*)&dst[i * strides[0]] =
interpolate_aa_single_dim<scalar_t, index_t>(
src + i * strides[1], &data[2], &strides[2], i, ids_stride);
}
}
}
template <int m>
static inline bool is_zero_stride(const int64_t* strides) {
bool output = strides[0] == 0;
for (int i = 1; i < m; i++) {
output &= (strides[i] == 0);
}
return output;
}
template <typename scalar_t, typename index_t, int out_ndims>
void ti_cpu_upsample_generic_aa(
at::TensorIterator& iter,
int interp_size = -1) {
TORCH_INTERNAL_ASSERT(interp_size > 0);
auto loop = [&](char** data, const int64_t* strides, int64_t n) {
if ((strides[0] == sizeof(scalar_t)) && (strides[1] == sizeof(scalar_t)) &&
is_zero_stride<3 + 2>(&strides[2])) {
basic_loop_aa_single_dim_zero_strides<scalar_t, index_t>(
data, strides, n);
} else {
basic_loop_aa_single_dim_nonzero_strides<scalar_t, index_t>(
data, strides, n);
}
};
iter.for_each(loop);
}
// Helper structs to use with ti_upsample_generic_Nd_kernel_impl
template <typename index_t, typename scalar_t>
struct HelperInterpBase {
template <typename filter_fn_t>
static inline void _compute_weights_aa(
const int64_t i,
const int64_t input_size,
const scalar_t scale,
const scalar_t support,
scalar_t* wt_ptr,
const int64_t interp_size,
filter_fn_t filter_fn,
int64_t& xmin,
int64_t& xsize) {
scalar_t center = scale * (i + 0.5);
scalar_t total_w = 0.0;
scalar_t invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
xmin = std::max(
static_cast<int64_t>(center - support + 0.5), static_cast<index_t>(0));
xsize = std::min(static_cast<int64_t>(center + support + 0.5), input_size) -
xmin;
int64_t j = 0;
for (; j < xsize; j++) {
scalar_t w = filter_fn((j + xmin - center + 0.5) * invscale);
wt_ptr[j] = w;
total_w += w;
}
for (j = 0; j < xsize; j++) {
if (total_w != 0.0) {
wt_ptr[j] /= total_w;
}
}
for (; j < interp_size; j++) {
wt_ptr[j] = static_cast<scalar_t>(0.0);
}
}
template <typename filter_fn_t>
static inline std::vector<Tensor> _compute_indices_weights_aa(
int64_t input_size,
int64_t output_size,
int64_t stride,
int64_t ndims,
int64_t reshape_dim,
bool align_corners,
scalar_t scale,
int& in_out_interp_size,
filter_fn_t filter_fn) {
int interp_size = in_out_interp_size;
scalar_t support =
(scale >= 1.0) ? (interp_size * 0.5) * scale : interp_size * 0.5;
interp_size = (int)ceilf(support) * 2 + 1;
// return interp_size
in_out_interp_size = interp_size;
std::vector<Tensor> output;
auto new_shape = std::vector<int64_t>(ndims, 1);
new_shape[reshape_dim] = output_size;
// ---- Bounds approach as in PIL -----
// bounds: xmin/xmax
output.emplace_back(
empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
output.emplace_back(
empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
output.emplace_back(
empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
{
// Weights
new_shape[reshape_dim] = output_size * interp_size;
auto wts = empty(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>()));
auto strides = wts.strides().vec();
strides[reshape_dim] = 0;
new_shape[reshape_dim] = output_size;
wts = wts.as_strided(new_shape, strides);
output.emplace_back(wts);
// Weights indices
output.emplace_back(
empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
}
int64_t* idx_ptr_xmin = output[0].data_ptr<index_t>();
int64_t* idx_ptr_size = output[1].data_ptr<index_t>();
int64_t* idx_ptr_stride = output[2].data_ptr<index_t>();
scalar_t* wt_ptr = output[3].data_ptr<scalar_t>();
int64_t* wt_idx_ptr = output[4].data_ptr<index_t>();
int64_t xmin, xmax;
for (int64_t i = 0; i < output_size; i++) {
HelperInterpBase<index_t, scalar_t>::_compute_weights_aa(
i,
input_size,
scale,
support,
wt_ptr + i * interp_size,
interp_size,
filter_fn,
xmin,
xmax);
idx_ptr_xmin[i] = xmin * stride;
idx_ptr_size[i] = xmax;
idx_ptr_stride[i] = stride;
wt_idx_ptr[i] = i * interp_size * sizeof(scalar_t);
}
return output;
}
};
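// Worked example (illustrative only): a bilinear downscale with scale = 4.0
// gives support = 0.5 * interp_size * scale = 4.0 and a padded per-output
// weight count of ceil(4.0) * 2 + 1 = 9. For output index i = 0 the filter is
// centered at scale * (i + 0.5) = 2.0, the taps start at
// xmin = max(center - support + 0.5, 0) = 0 and span xsize = 6 source pixels
// (assuming the input is at least that wide), and the raw weights
// filter((j + xmin - center + 0.5) / scale) are normalized to sum to 1,
// following the PIL bounds approach referenced above.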
template <typename index_t, typename scalar_t>
struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {
static const int interp_size = 2;
// taken from
// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
// src/libImaging/Resample.c#L20-L29
static inline scalar_t _filter(scalar_t x) {
if (x < 0.0) {
x = -x;
}
if (x < 1.0) {
return 1.0 - x;
}
return 0.0;
}
static inline std::vector<Tensor> compute_indices_weights(
int64_t input_size,
int64_t output_size,
int64_t stride,
int64_t ndims,
int64_t reshape_dim,
bool align_corners,
const c10::optional<double> opt_scale,
bool antialias,
int& out_interp_size) {
TORCH_INTERNAL_ASSERT(antialias);
scalar_t scale = area_pixel_compute_scale<scalar_t>(
input_size, output_size, align_corners, opt_scale);
out_interp_size = HelperInterpLinear<index_t, scalar_t>::interp_size;
return HelperInterpLinear<index_t, scalar_t>::_compute_indices_weights_aa(
input_size,
output_size,
stride,
ndims,
reshape_dim,
align_corners,
scale,
out_interp_size,
_filter);
}
};
template <typename index_t, typename scalar_t>
struct HelperInterpCubic : public HelperInterpBase<index_t, scalar_t> {
static const int interp_size = 4;
static inline std::vector<Tensor> compute_indices_weights(
int64_t input_size,
int64_t output_size,
int64_t stride,
int64_t ndims,
int64_t reshape_dim,
bool align_corners,
const c10::optional<double> opt_scale,
bool antialias,
int& out_interp_size) {
TORCH_INTERNAL_ASSERT(antialias);
scalar_t scale = area_pixel_compute_scale<scalar_t>(
input_size, output_size, align_corners, opt_scale);
out_interp_size = HelperInterpCubic<index_t, scalar_t>::interp_size;
return HelperInterpCubic<index_t, scalar_t>::_compute_indices_weights_aa(
input_size,
output_size,
stride,
ndims,
reshape_dim,
align_corners,
scale,
out_interp_size,
_filter);
}
// taken from
// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
// src/libImaging/Resample.c#L46-L62
static inline scalar_t _filter(scalar_t x) {
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
#define a -0.5
if (x < 0.0) {
x = -x;
}
if (x < 1.0) {
return ((a + 2.0) * x - (a + 3.0)) * x * x + 1;
}
if (x < 2.0) {
return (((x - 5) * x + 8) * x - 4) * a;
}
return 0.0;
#undef a
}
};
template <
typename index_t,
int out_ndims,
typename scale_type,
template <typename, typename>
class F>
void _ti_separable_upsample_generic_Nd_kernel_impl_single_dim(
Tensor& output,
const Tensor& input,
int interp_dim,
bool align_corners,
const scale_type& scales,
bool antialias) {
// input can be NCHW, NCL or NCKHW
auto shape = input.sizes().vec();
auto strides = input.strides().vec();
auto oshape = output.sizes();
TORCH_INTERNAL_ASSERT(
shape.size() == oshape.size() && shape.size() == 2 + out_ndims);
TORCH_INTERNAL_ASSERT(strides.size() == 2 + out_ndims);
TORCH_INTERNAL_ASSERT(antialias);
for (int i = 0; i < out_ndims; i++) {
shape[i + 2] = oshape[i + 2];
}
strides[interp_dim] = 0;
auto restrided_input = input.as_strided(shape, strides);
std::vector<std::vector<Tensor>> indices_weights;
int interp_size = F<index_t, float>::interp_size;
auto input_scalar_type = input.scalar_type();
if (interp_size == 1 && input_scalar_type == at::ScalarType::Byte) {
// nearest also supports uint8 tensor, but we have to use float
// with compute_indices_weights
input_scalar_type = at::ScalarType::Float;
}
AT_DISPATCH_FLOATING_TYPES_AND(
at::ScalarType::Byte,
input_scalar_type,
"compute_indices_weights_generic",
[&] {
indices_weights.emplace_back(
F<index_t, scalar_t>::compute_indices_weights(
input.size(interp_dim),
oshape[interp_dim],
input.stride(interp_dim) * input.element_size(),
input.dim(),
interp_dim,
align_corners,
scales[interp_dim - 2],
antialias,
interp_size));
});
TensorIteratorConfig config;
config.check_all_same_dtype(false)
.declare_static_dtype_and_device(input.scalar_type(), input.device())
.add_output(output)
.add_input(restrided_input);
for (auto& idx_weight : indices_weights) {
for (auto& tensor : idx_weight) {
config.add_input(tensor);
}
}
auto iter = config.build();
if (interp_size > 1) {
// Nearest also supports uint8 tensor, so need to handle it separately
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "upsample_generic_Nd", [&] {
ti_cpu_upsample_generic_aa<scalar_t, index_t, out_ndims>(
iter, interp_size);
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND(
at::ScalarType::Byte, iter.dtype(), "upsample_generic_Nd", [&] {
ti_cpu_upsample_generic_aa<scalar_t, index_t, out_ndims>(
iter, interp_size);
});
}
}
template <
typename index_t,
int out_ndims,
typename scale_type,
template <typename, typename>
class F>
void ti_separable_upsample_generic_Nd_kernel_impl(
Tensor& output,
const Tensor& input,
bool align_corners,
const scale_type& scales,
bool antialias) {
auto temp_oshape = input.sizes().vec();
at::Tensor temp_output, temp_input = input;
for (int i = 0; i < out_ndims - 1; i++) {
int interp_dim = 2 + out_ndims - 1 - i;
temp_oshape[interp_dim] = output.sizes()[interp_dim];
temp_output = at::empty(temp_oshape, input.options());
_ti_separable_upsample_generic_Nd_kernel_impl_single_dim<
index_t,
out_ndims,
scale_t,
F>(
temp_output, temp_input, interp_dim, align_corners, scales, antialias);
temp_input = temp_output;
}
_ti_separable_upsample_generic_Nd_kernel_impl_single_dim<
index_t,
out_ndims,
scale_t,
F>(output, temp_input, 2, align_corners, scales, antialias);
}
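// For a 2D resize (out_ndims = 2) the loop above performs one temporary pass
// with interp_dim = 3 (width) into temp_output, and the final call then
// interpolates dim 2 (height): the anti-aliased resize is applied separably,
// one axis at a time.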
void _ti_upsample_bilinear2d_kernel_impl(
Tensor& output,
const Tensor& input,
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
bool antialias) {
ti_separable_upsample_generic_Nd_kernel_impl<
int64_t,
2,
scale_t,
HelperInterpLinear>(
output, input, align_corners, {scales_h, scales_w}, antialias);
}
void _ti_upsample_bicubic2d_kernel_impl(
Tensor& output,
const Tensor& input,
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w,
bool antialias) {
ti_separable_upsample_generic_Nd_kernel_impl<
int64_t,
2,
scale_t,
HelperInterpCubic>(
output, input, align_corners, {scales_h, scales_w}, antialias);
}
template <
typename scalar_t,
typename scale_type,
template <typename, typename>
class F>
void cpu_upsample_genNd_backward_aa(
const Tensor& grad_input_,
const Tensor& grad_output_,
bool align_corners,
const scale_type& scales) {
TORCH_CHECK(
grad_input_.dtype() == grad_output_.dtype(),
"expected dtype ",
grad_output_.dtype(),
" for `grad_input` but got dtype ",
grad_input_.dtype());
auto grad_output = grad_output_.contiguous();
auto grad_input = grad_input_.contiguous();
auto grad_output_data = grad_output.data_ptr<scalar_t>();
auto grad_input_data = grad_input.data_ptr<scalar_t>();
auto input_sizes = grad_input.sizes().vec();
auto output_sizes = grad_output.sizes().vec();
auto ndim = input_sizes.size();
// treat nbatch and channels as one dimension
int64_t channels = input_sizes[0] * input_sizes[1];
int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
int64_t input_height = (ndim >= 4) ? input_sizes[ndim - 2] : 1;
int64_t output_height = (ndim >= 4) ? output_sizes[ndim - 2] : 1;
int64_t input_width = input_sizes[ndim - 1];
int64_t output_width = output_sizes[ndim - 1];
int64_t output_slice_size = output_depth * output_height * output_width;
int interp_size = F<int64_t, float>::interp_size;
auto loop2d = [&](int64_t begin, int64_t end) {
const scalar_t height_scale = area_pixel_compute_scale<scalar_t>(
input_height, output_height, align_corners, scales[0]);
const scalar_t width_scale = area_pixel_compute_scale<scalar_t>(
input_width, output_width, align_corners, scales[1]);
auto input_indexr = [=](int64_t c, int64_t h, int64_t w) {
return grad_input_data + c * input_height * input_width +
h * input_width + w;
};
const scalar_t support_h = (height_scale >= 1.0)
? (interp_size * 0.5) * height_scale
: interp_size * 0.5;
const scalar_t support_w = (width_scale >= 1.0)
? (interp_size * 0.5) * width_scale
: interp_size * 0.5;
const int interp_height = (int)ceilf(support_h) * 2 + 1;
const int interp_width = (int)ceilf(support_w) * 2 + 1;
std::vector<scalar_t> wx(interp_width, 0.0);
std::vector<scalar_t> wy(interp_height, 0.0);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t xmin, ymin;
int64_t xsize, ysize;
auto filter_fn = F<int64_t, scalar_t>::_filter;
for (int64_t oh = 0; oh < output_height; oh++) {
F<int64_t, scalar_t>::_compute_weights_aa(
oh,
input_height,
height_scale,
support_h,
wy.data(),
interp_height,
filter_fn,
ymin,
ysize);
for (int64_t ow = 0; ow < output_width; ow++) {
F<int64_t, scalar_t>::_compute_weights_aa(
ow,
input_width,
width_scale,
support_w,
wx.data(),
interp_width,
filter_fn,
xmin,
xsize);
for (int64_t c = begin; c < end; c++) {
scalar_t grad_output_value =
grad_output_data[c * output_slice_size + oh * output_width + ow];
for (size_t y = 0; y < ysize; y++) {
for (size_t x = 0; x < xsize; x++) {
*input_indexr(c, ymin + y, xmin + x) +=
wx[x] * wy[y] * grad_output_value;
}
}
}
}
}
};
if (ndim == 4) {
// upsample bilinear 2d
at::parallel_for(
0, channels, at::internal::GRAIN_SIZE / output_slice_size / 4, loop2d);
} else {
TORCH_CHECK(false, "Unsupported tensor ndim");
}
if (!grad_input_.is_contiguous()) {
grad_input_.copy_(grad_input);
}
}
void _upsample_bilinear2d_aa_backward_kernel_impl(
const Tensor& grad_input,
const Tensor& grad_output,
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w) {
AT_DISPATCH_FLOATING_TYPES(
grad_output.scalar_type(), "upsample_bilinear2d_backward_cpu", [&] {
cpu_upsample_genNd_backward_aa<scalar_t, scale_t, HelperInterpLinear>(
grad_input, grad_output, align_corners, {scales_h, scales_w});
});
}
void _upsample_bicubic2d_aa_backward_kernel_impl(
const Tensor& grad_input,
const Tensor& grad_output,
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w) {
AT_DISPATCH_FLOATING_TYPES(
grad_output.scalar_type(), "upsample_bicubic2d_backward_cpu", [&] {
cpu_upsample_genNd_backward_aa<scalar_t, scale_t, HelperInterpCubic>(
grad_input, grad_output, align_corners, {scales_h, scales_w});
});
}
} // namespace internal_upsample
} // namespace native
} // namespace at
namespace vision {
namespace ops {
namespace {
at::Tensor interpolate_bilinear2d_aa_forward_kernel(
const at::Tensor& input,
at::IntArrayRef output_size,
bool align_corners) {
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
c10::optional<c10::ArrayRef<double>> scale_factors = {};
// Copied from UpSampleBilinear2d.cpp
auto output = at::empty({0}, input.options());
auto osize = at::native::upsample::compute_output_size(
input.sizes(), output_size, scale_factors);
auto scale_h = at::native::upsample::get_scale_value(scale_factors, 0);
auto scale_w = at::native::upsample::get_scale_value(scale_factors, 1);
auto full_output_size =
at::native::upsample_2d_common_check(input.sizes(), osize);
// Allow for empty batch size but not other dimensions
TORCH_CHECK(
input.numel() != 0 ||
c10::multiply_integers(
input.sizes().begin() + 1, input.sizes().end()),
"Non-empty 4D data tensor expected but got a tensor with sizes ",
input.sizes());
output.resize_(full_output_size, input.suggest_memory_format());
at::native::internal_upsample::_ti_upsample_bilinear2d_kernel_impl(
output, input, align_corners, scale_h, scale_w, /*antialias=*/true);
return output;
}
at::Tensor interpolate_bicubic2d_aa_forward_kernel(
const at::Tensor& input,
at::IntArrayRef output_size,
bool align_corners) {
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
c10::optional<c10::ArrayRef<double>> scale_factors = {};
// Copied from UpSampleBilinear2d.cpp
auto output = at::empty({0}, input.options());
auto osize = at::native::upsample::compute_output_size(
input.sizes(), output_size, scale_factors);
auto scale_h = at::native::upsample::get_scale_value(scale_factors, 0);
auto scale_w = at::native::upsample::get_scale_value(scale_factors, 1);
auto full_output_size =
at::native::upsample_2d_common_check(input.sizes(), osize);
// Allow for empty batch size but not other dimensions
TORCH_CHECK(
input.numel() != 0 ||
c10::multiply_integers(
input.sizes().begin() + 1, input.sizes().end()),
"Non-empty 4D data tensor expected but got a tensor with sizes ",
input.sizes());
output.resize_(full_output_size, input.suggest_memory_format());
at::native::internal_upsample::_ti_upsample_bicubic2d_kernel_impl(
output, input, align_corners, scale_h, scale_w, /*antialias=*/true);
return output;
}
at::Tensor interpolate_bilinear2d_aa_backward_kernel(
const at::Tensor& grad_output,
at::IntArrayRef output_size,
at::IntArrayRef input_size,
bool align_corners) {
c10::optional<c10::ArrayRef<double>> scale_factors = {};
// Copied from UpSampleBilinear2d.cpp::upsample_bilinear2d_backward
auto grad_input = at::empty({0}, grad_output.options());
auto osize = at::native::upsample::compute_output_size(
input_size, output_size, scale_factors);
auto scale_h = at::native::upsample::get_scale_value(scale_factors, 0);
auto scale_w = at::native::upsample::get_scale_value(scale_factors, 1);
auto full_output_size =
at::native::upsample_2d_common_check(input_size, osize);
TORCH_CHECK(
grad_output.dim() == 4,
"Expected grad_output to be a tensor of dimension 4 but got: dimension ",
grad_output.dim());
for (int i = 0; i < 4; ++i) {
TORCH_CHECK(
grad_output.size(i) == full_output_size[i],
"Expected grad_output to have the same shape as output;",
" output.size(",
i,
") = ",
full_output_size[i],
" but got grad_output.size(",
i,
") = ",
grad_output.size(i));
}
grad_input.resize_(input_size, grad_output.suggest_memory_format());
grad_input.zero_();
at::native::internal_upsample::_upsample_bilinear2d_aa_backward_kernel_impl(
grad_input, grad_output, align_corners, scale_h, scale_w);
return grad_input;
}
at::Tensor interpolate_bicubic2d_aa_backward_kernel(
const at::Tensor& grad_output,
at::IntArrayRef output_size,
at::IntArrayRef input_size,
bool align_corners) {
c10::optional<c10::ArrayRef<double>> scale_factors = {};
// Copied from UpSampleBicubic2d.cpp::upsample_bicubic2d_backward
auto grad_input = at::empty({0}, grad_output.options());
auto osize = at::native::upsample::compute_output_size(
input_size, output_size, scale_factors);
auto scale_h = at::native::upsample::get_scale_value(scale_factors, 0);
auto scale_w = at::native::upsample::get_scale_value(scale_factors, 1);
auto full_output_size =
at::native::upsample_2d_common_check(input_size, osize);
TORCH_CHECK(
grad_output.dim() == 4,
"Expected grad_output to be a tensor of dimension 4 but got: dimension ",
grad_output.dim());
for (int i = 0; i < 4; ++i) {
TORCH_CHECK(
grad_output.size(i) == full_output_size[i],
"Expected grad_output to have the same shape as output;",
" output.size(",
i,
") = ",
full_output_size[i],
" but got grad_output.size(",
i,
") = ",
grad_output.size(i));
}
grad_input.resize_(input_size, grad_output.suggest_memory_format());
grad_input.zero_();
at::native::internal_upsample::_upsample_bicubic2d_aa_backward_kernel_impl(
grad_input, grad_output, align_corners, scale_h, scale_w);
return grad_input;
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa"),
TORCH_FN(interpolate_bilinear2d_aa_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa"),
TORCH_FN(interpolate_bicubic2d_aa_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa_backward"),
TORCH_FN(interpolate_bilinear2d_aa_backward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa_backward"),
TORCH_FN(interpolate_bicubic2d_aa_backward_kernel));
}
} // namespace ops
} // namespace vision
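// ---- interpolate_aa_kernels.cu (CUDA kernels for the custom anti-aliased
// interpolation ops; the whole file is removed by this PR) ----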
#include <torch/library.h>
// Copied and adapted from interp.cpp from Caffe util by Pauline Luc
// Originally developed by George Papandreou
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include <ATen/native/cuda/UpSample.cuh>
// Below is experimental, temporary code that will eventually be merged into PyTorch
namespace at {
namespace native {
namespace internal_upsample {
__device__ __forceinline__ size_t
idx(const size_t nc,
const size_t height,
const size_t width,
const size_t y,
const size_t x) {
return (nc * height + y) * width + x;
}
// taken from
// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
// src/libImaging/Resample.c#L20-L29
template <typename accscalar_t>
__device__ __forceinline__ static accscalar_t bilinear_filter(accscalar_t x) {
if (x < 0.0) {
x = -x;
}
if (x < 1.0) {
return static_cast<accscalar_t>(1.0) - x;
}
return static_cast<accscalar_t>(0.0);
}
// taken from
// https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
// src/libImaging/Resample.c#L46-L62
template <typename accscalar_t>
__device__ __forceinline__ static accscalar_t bicubic_filter(accscalar_t x) {
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
#define a -0.5
if (x < 0.0) {
x = -x;
}
if (x < 1.0) {
return ((a + 2.0) * x - (a + 3.0)) * x * x + static_cast<accscalar_t>(1.0);
}
if (x < 2.0) {
return (((x - 5) * x + 8) * x - 4) * a;
}
return static_cast<accscalar_t>(0.0);
#undef a
}
template <typename scalar_t, typename accscalar_t, typename filter_fn_t>
__device__ __forceinline__ static void _compute_weights(
const int i,
const int input_size,
const accscalar_t scale,
const accscalar_t support,
scalar_t* wt_ptr,
int interp_size,
filter_fn_t filter_fn,
int& xmin,
int& xmax) {
accscalar_t invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
accscalar_t center = scale * (i + 0.5);
xmin = max(static_cast<int>(center - support + 0.5), static_cast<int>(0));
xmax = min(static_cast<int>(center + support + 0.5), input_size) - xmin;
accscalar_t total_w = 0.0;
int j = 0;
for (j = 0; j < xmax; j++) {
accscalar_t w = filter_fn((j + xmin - center + 0.5) * invscale);
wt_ptr[j] = static_cast<scalar_t>(w);
total_w += w;
}
for (j = 0; j < xmax; j++) {
if (total_w != 0.0) {
wt_ptr[j] /= total_w;
}
}
for (; j < interp_size; j++) {
wt_ptr[j] = static_cast<scalar_t>(0.0);
}
}
template <typename scalar_t, typename accscalar_t>
__device__ __forceinline__ static accscalar_t interpolate_aa_single_dim(
scalar_t* src,
scalar_t* weights,
int64_t size) {
scalar_t t = static_cast<accscalar_t>(*src);
scalar_t wts = static_cast<accscalar_t>(weights[0]);
accscalar_t output = t * wts;
int64_t j = 1;
for (; j < size; j++) {
wts = static_cast<accscalar_t>(weights[j]);
t = static_cast<accscalar_t>(*(src + j));
output += t * wts;
}
return output;
}
template <typename scalar_t, typename accscalar_t, int interp_size>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_gen2d_out_frame(
const int n,
const accscalar_t rheight,
const accscalar_t rwidth,
const bool align_corners,
const PackedTensorAccessor64<scalar_t, 4> idata,
PackedTensorAccessor64<scalar_t, 4> odata) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
const int batchsize = idata.size(0);
const int channels = idata.size(1);
const int height1 = idata.size(2);
const int width1 = idata.size(3);
const int height2 = odata.size(2);
const int width2 = odata.size(3);
if (index < n) {
const int w2 = index % width2; // 0:width2-1
const int h2 = index / width2; // 0:height2-1
// special case: just copy
if (height1 == height2 && width1 == width2) {
const int h1 = h2;
const int w1 = w2;
for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; ++c) {
const scalar_t val = idata[n][c][h1][w1];
odata[n][c][h2][w2] = val;
}
}
return;
}
const accscalar_t support_h = static_cast<accscalar_t>(
(rheight >= 1.0) ? (interp_size * 0.5) * rheight : interp_size * 0.5);
const accscalar_t support_w = static_cast<accscalar_t>(
(rwidth >= 1.0) ? (interp_size * 0.5) * rwidth : interp_size * 0.5);
const int interp_height = (int)ceilf(support_h) * 2 + 1;
const int interp_width = (int)ceilf(support_w) * 2 + 1;
// Setup local buffers
// TODO: maybe we can specify dynamic shared memory size before calling the
// cuda code, however we should then ensure that device has enough shared
// memory
scalar_t wx[256];
scalar_t wy[256];
scalar_t buffer1[256];
scalar_t buffer2[256];
// Compute weights
int xmin, xsize, ymin, ysize;
typedef scalar_t (*filter_fn_t)(scalar_t);
filter_fn_t filter_fn;
if (interp_size == 2) {
filter_fn = bilinear_filter;
} else if (interp_size == 4) {
filter_fn = bicubic_filter;
}
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
w2,
width1,
rwidth,
support_w,
wx,
interp_width,
filter_fn,
xmin,
xsize);
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
h2,
height1,
rheight,
support_h,
wy,
interp_height,
filter_fn,
ymin,
ysize);
for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; ++c) {
// interpolate on x-axis for ymin to ymin + ysize
for (int y = 0; y < ysize; y++) {
// copy data into the local buffer and use
// interpolate_aa_single_dim method
for (int x = 0; x < xsize; x++) {
buffer1[x] = idata[n][c][ymin + y][xmin + x];
}
buffer2[y] = static_cast<scalar_t>(
interpolate_aa_single_dim<scalar_t, accscalar_t>(
buffer1, wx, xsize));
}
odata[n][c][h2][w2] = static_cast<scalar_t>(
interpolate_aa_single_dim<scalar_t, accscalar_t>(
buffer2, wy, ysize));
}
}
}
}
template <int interp_size>
static void upsample_gen2d_out_cuda_template(
const Tensor& output,
const Tensor& input,
IntArrayRef output_size,
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w) {
// Copied and adapted from
// UpSampleBicubic2d.cu::upsample_bicubic2d_out_cuda_template
TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2};
checkAllSameGPU("upsample_gen2d_out_cuda", {input_arg, output_arg});
int output_height = output_size[0];
int output_width = output_size[1];
int nbatch = input.size(0);
int channels = input.size(1);
int input_height = input.size(2);
int input_width = input.size(3);
const int num_kernels = output_height * output_width;
const int num_threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "upsample_gen2d_out_frame", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
auto idata = input.packed_accessor64<scalar_t, 4>();
auto odata = output.packed_accessor64<scalar_t, 4>();
const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
input_height, output_height, align_corners, scales_h);
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
input_width, output_width, align_corners, scales_w);
// We are using static buffer memory of 256 * sizeof(float) per thread
// to store weights. Size of weights array is
// interp_size = scale * 2 + 1 for bilinear mode
TORCH_CHECK(
rheight < (255 / interp_size),
"Max supported scale factor is 127 (bilinear), 63 (bicubic)");
TORCH_CHECK(
rwidth < (255 / interp_size),
"Max supported scale factor is 127 (bilinear), 63 (bicubic)");
upsample_gen2d_out_frame<scalar_t, accscalar_t, interp_size>
<<<cuda::ATenCeilDiv(num_kernels, num_threads),
num_threads,
0,
stream>>>(
num_kernels, rheight, rwidth, align_corners, idata, odata);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
// Backward (adjoint) operation 1 <- 2 (accumulates)
template <typename scalar_t, typename accscalar_t, int interp_size>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_gen2d_backward_out_frame(
const int num_elements,
const accscalar_t height_scale,
const accscalar_t width_scale,
const bool align_corners,
PackedTensorAccessor64<scalar_t, 4> idata,
const PackedTensorAccessor64<scalar_t, 4> odata) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
const int batchsize = idata.size(0);
const int channels = idata.size(1);
const int input_height = idata.size(2);
const int input_width = idata.size(3);
const int output_height = odata.size(2);
const int output_width = odata.size(3);
if (index >= num_elements) {
return;
}
const int output_x = index % output_width;
const int output_y = index / output_width;
// special case: output just copy
if (input_height == output_height && input_width == output_width) {
for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; ++c) {
const scalar_t val = odata[n][c][output_y][output_x];
idata[n][c][output_y][output_x] = val;
}
}
return;
}
const accscalar_t support_h = static_cast<accscalar_t>(
(height_scale >= 1.0) ? (interp_size * 0.5) * height_scale
: interp_size * 0.5);
const accscalar_t support_w = static_cast<accscalar_t>(
(width_scale >= 1.0) ? (interp_size * 0.5) * width_scale
: interp_size * 0.5);
const int interp_height = (int)ceilf(support_h) * 2 + 1;
const int interp_width = (int)ceilf(support_w) * 2 + 1;
// Setup local buffers
// TODO: maybe we can specify dynamic shared memory size before calling the
// cuda code, however we should then ensure that device has enough shared
// memory
scalar_t wx[256];
scalar_t wy[256];
// Compute weights
int xmin, xsize, ymin, ysize;
typedef scalar_t (*filter_fn_t)(scalar_t);
filter_fn_t filter_fn;
if (interp_size == 2) {
filter_fn = bilinear_filter;
} else if (interp_size == 4) {
filter_fn = bicubic_filter;
}
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
output_x,
input_width,
width_scale,
support_w,
wx,
interp_width,
filter_fn,
xmin,
xsize);
_compute_weights<scalar_t, accscalar_t, filter_fn_t>(
output_y,
input_height,
height_scale,
support_h,
wy,
interp_height,
filter_fn,
ymin,
ysize);
for (int n = 0; n < batchsize; n++) {
for (int c = 0; c < channels; ++c) {
scalar_t out_value = odata[n][c][output_y][output_x];
for (int y = 0; y < ysize; y++) {
for (int x = 0; x < xsize; x++) {
upsample_increment_value_bounded<scalar_t, accscalar_t>(
idata,
n,
c,
input_height,
input_width,
ymin + y,
xmin + x,
wx[x] * wy[y] * out_value);
}
}
}
}
}
template <int interp_size>
static void upsample_gen2d_backward_out_cuda_template(
const Tensor& grad_input,
const Tensor& grad_output_,
IntArrayRef output_size,
IntArrayRef input_size,
bool align_corners,
c10::optional<double> scales_h,
c10::optional<double> scales_w) {
// Copied and adapted from
// UpSampleBicubic2d.cu::upsample_bicubic2d_backward_out_cuda_template
TensorArg grad_input_arg{grad_input, "grad_input", 1},
grad_output_arg{grad_output_, "grad_output_", 2};
checkAllSameGPU(
"upsample_gen2d_backward_out_cuda", {grad_output_arg, grad_input_arg});
int output_height = output_size[0];
int output_width = output_size[1];
int nbatch = input_size[0];
int channels = input_size[1];
int input_height = input_size[2];
int input_width = input_size[3];
Tensor grad_output = grad_output_.contiguous();
grad_input.zero_();
const int num_kernels = output_height * output_width;
const int num_threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(), "upsample_gen2d_backward_out_frame", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
auto idata = grad_input.packed_accessor64<scalar_t, 4>();
auto odata = grad_output.packed_accessor64<scalar_t, 4>();
const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
input_height, output_height, align_corners, scales_h);
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
input_width, output_width, align_corners, scales_w);
// We are using static buffer memory of 256 * sizeof(float) per thread
// to store weights. Size of weights array is
// interp_size = scale * 2 + 1 for bilinear mode
TORCH_CHECK(
rheight < (255 / interp_size),
"Max supported scale factor is 127 (bilinear), 63 (bicubic)");
TORCH_CHECK(
rwidth < (255 / interp_size),
"Max supported scale factor is 127 (bilinear), 63 (bicubic)");
upsample_gen2d_backward_out_frame<scalar_t, accscalar_t, interp_size>
<<<cuda::ATenCeilDiv(num_kernels, num_threads),
num_threads,
0,
stream>>>(
num_kernels, rheight, rwidth, align_corners, idata, odata);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
} // namespace internal_upsample
} // namespace native
} // namespace at
namespace vision {
namespace ops {
namespace {
// Copied from "UpSample.h" as we can not use UpSample.h with UpSample.cuh
static std::array<int64_t, 4> upsample_2d_common_check(
at::IntArrayRef input_size,
at::IntArrayRef output_size) {
TORCH_CHECK(
output_size.size() == 2,
"It is expected output_size equals to 2, but got size ",
output_size.size());
TORCH_CHECK(
input_size.size() == 4,
"It is expected input_size equals to 4, but got size ",
input_size.size());
int64_t output_height = output_size[0];
int64_t output_width = output_size[1];
int64_t nbatch = input_size[0];
int64_t channels = input_size[1];
int64_t input_height = input_size[2];
int64_t input_width = input_size[3];
TORCH_CHECK(
input_height > 0 && input_width > 0 && output_height > 0 &&
output_width > 0,
"Input and output sizes should be greater than 0,"
" but got input (H: ",
input_height,
", W: ",
input_width,
") output (H: ",
output_height,
", W: ",
output_width,
")");
return {nbatch, channels, output_height, output_width};
}
template <int interp_size>
at::Tensor interpolate_gen2d_aa_forward_kernel(
const at::Tensor& input,
at::IntArrayRef output_size,
bool align_corners) {
c10::optional<c10::ArrayRef<double>> scale_factors = {};
// Copied from UpSampleBilinear2d.cpp
auto output = at::empty({0}, input.options());
auto osize = at::native::upsample::compute_output_size(
input.sizes(), output_size, scale_factors);
auto scale_h = at::native::upsample_cuda::get_scale_value(scale_factors, 0);
auto scale_w = at::native::upsample_cuda::get_scale_value(scale_factors, 1);
auto full_output_size = upsample_2d_common_check(input.sizes(), osize);
// Allow for empty batch size but not other dimensions
TORCH_CHECK(
input.numel() != 0 ||
c10::multiply_integers(
input.sizes().begin() + 1, input.sizes().end()),
"Non-empty 4D data tensor expected but got a tensor with sizes ",
input.sizes());
output.resize_(full_output_size, input.suggest_memory_format());
at::native::internal_upsample::upsample_gen2d_out_cuda_template<interp_size>(
output,
input,
{full_output_size[2], full_output_size[3]},
align_corners,
scale_h,
scale_w);
return output;
}
template <int interp_size>
at::Tensor interpolate_gen2d_aa_backward_kernel(
const at::Tensor& grad_output,
at::IntArrayRef output_size,
at::IntArrayRef input_size,
bool align_corners) {
c10::optional<c10::ArrayRef<double>> scale_factors = {};
// Copied from UpSampleBicubic2d.cpp::upsample_bicubic2d_backward
auto grad_input = at::empty({0}, grad_output.options());
auto osize = at::native::upsample::compute_output_size(
input_size, output_size, scale_factors);
auto scale_h = at::native::upsample_cuda::get_scale_value(scale_factors, 0);
auto scale_w = at::native::upsample_cuda::get_scale_value(scale_factors, 1);
auto full_output_size = upsample_2d_common_check(input_size, osize);
TORCH_CHECK(
grad_output.dim() == 4,
"Expected grad_output to be a tensor of dimension 4 but got: dimension ",
grad_output.dim());
for (int i = 0; i < 4; ++i) {
TORCH_CHECK(
grad_output.size(i) == full_output_size[i],
"Expected grad_output to have the same shape as output;",
" output.size(",
i,
") = ",
full_output_size[i],
" but got grad_output.size(",
i,
") = ",
grad_output.size(i));
}
grad_input.resize_(input_size, grad_output.suggest_memory_format());
at::native::internal_upsample::upsample_gen2d_backward_out_cuda_template<
interp_size>(
grad_input,
grad_output,
{full_output_size[2], full_output_size[3]},
input_size,
align_corners,
scale_h,
scale_w);
return grad_input;
}
at::Tensor interpolate_bilinear2d_aa_forward_kernel(
const at::Tensor& input,
at::IntArrayRef output_size,
bool align_corners) {
return interpolate_gen2d_aa_forward_kernel<2>(
input, output_size, align_corners);
}
at::Tensor interpolate_bicubic2d_aa_forward_kernel(
const at::Tensor& input,
at::IntArrayRef output_size,
bool align_corners) {
return interpolate_gen2d_aa_forward_kernel<4>(
input, output_size, align_corners);
}
at::Tensor interpolate_bilinear2d_aa_backward_kernel(
const at::Tensor& grad_output,
at::IntArrayRef output_size,
at::IntArrayRef input_size,
bool align_corners) {
return interpolate_gen2d_aa_backward_kernel<2>(
grad_output, output_size, input_size, align_corners);
}
at::Tensor interpolate_bicubic2d_aa_backward_kernel(
const at::Tensor& grad_output,
at::IntArrayRef output_size,
at::IntArrayRef input_size,
bool align_corners) {
return interpolate_gen2d_aa_backward_kernel<4>(
grad_output, output_size, input_size, align_corners);
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa"),
TORCH_FN(interpolate_bilinear2d_aa_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa"),
TORCH_FN(interpolate_bicubic2d_aa_forward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa_backward"),
TORCH_FN(interpolate_bilinear2d_aa_backward_kernel));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa_backward"),
TORCH_FN(interpolate_bicubic2d_aa_backward_kernel));
}
} // namespace ops
} // namespace vision
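// ---- torchvision/csrc/ops/interpolate_aa.cpp (dispatcher wrappers for the
// custom anti-aliased interpolation ops; removed by this PR) ----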
#include "interpolate_aa.h"
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
#include <torch/types.h>
namespace vision {
namespace ops {
at::Tensor _interpolate_bilinear2d_aa(
const at::Tensor& input, // Input image
at::IntArrayRef output_size, // Output image size
bool align_corners) // The flag to align corners
{
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_interpolate_bilinear2d_aa", "")
.typed<decltype(_interpolate_bilinear2d_aa)>();
return op.call(input, output_size, align_corners);
}
at::Tensor _interpolate_bicubic2d_aa(
const at::Tensor& input, // Input image
at::IntArrayRef output_size, // Output image size
bool align_corners) // The flag to align corners
{
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_interpolate_bicubic2d_aa", "")
.typed<decltype(_interpolate_bicubic2d_aa)>();
return op.call(input, output_size, align_corners);
}
namespace detail {
at::Tensor _interpolate_bilinear2d_aa_backward(
const at::Tensor& grad_output,
at::IntArrayRef output_size,
at::IntArrayRef input_size,
bool align_corners) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow(
"torchvision::_interpolate_bilinear2d_aa_backward", "")
.typed<decltype(_interpolate_bilinear2d_aa_backward)>();
return op.call(grad_output, output_size, input_size, align_corners);
}
at::Tensor _interpolate_bicubic2d_aa_backward(
const at::Tensor& grad_output,
at::IntArrayRef output_size,
at::IntArrayRef input_size,
bool align_corners) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow(
"torchvision::_interpolate_bicubic2d_aa_backward", "")
.typed<decltype(_interpolate_bicubic2d_aa_backward)>();
return op.call(grad_output, output_size, input_size, align_corners);
}
} // namespace detail
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_interpolate_bilinear2d_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_interpolate_bicubic2d_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_interpolate_bilinear2d_aa_backward(Tensor input, int[] output_size, int[] input_size, bool align_corners) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_interpolate_bicubic2d_aa_backward(Tensor input, int[] output_size, int[] input_size, bool align_corners) -> Tensor"));
}
} // namespace ops
} // namespace vision
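// ---- torchvision/csrc/ops/interpolate_aa.h (public declarations for the
// custom anti-aliased interpolation ops; removed by this PR) ----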
#pragma once
#include <ATen/ATen.h>
#include "../macros.h"
namespace vision {
namespace ops {
VISION_API at::Tensor _interpolate_bilinear2d_aa(
const at::Tensor& input,
at::IntArrayRef output_size,
bool align_corners = false);
VISION_API at::Tensor _interpolate_bicubic2d_aa(
const at::Tensor& input,
at::IntArrayRef output_size,
bool align_corners = false);
namespace detail {
VISION_API at::Tensor _interpolate_bilinear2d_aa_backward(
const at::Tensor& grad,
at::IntArrayRef output_size,
at::IntArrayRef input_size,
bool align_corners = false);
VISION_API at::Tensor _interpolate_bicubic2d_aa_backward(
const at::Tensor& grad,
at::IntArrayRef output_size,
at::IntArrayRef input_size,
bool align_corners = false);
} // namespace detail
} // namespace ops
} // namespace vision
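Before this change, resize(..., antialias=True) in torchvision/transforms/functional_tensor.py (diff below) branched into the custom ops declared above. A sketch of the old call path, mirroring the removed test code, for reference only (these ops no longer exist after this PR):

import torch
import torchvision  # importing torchvision registers the torchvision::* custom ops

x = torch.rand(1, 3, 64, 64)
# old schema: _interpolate_bilinear2d_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor
y = torch.ops.torchvision._interpolate_bilinear2d_aa(x, [32, 32], False)
# old backward schema takes (grad_output, output_size, input_size, align_corners)
grad_x = torch.ops.torchvision._interpolate_bilinear2d_aa_backward(torch.ones_like(y), [32, 32], [1, 3, 64, 64], False)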
@@ -481,13 +481,7 @@ def resize(
# Define align_corners to avoid warnings
align_corners = False if interpolation in ["bilinear", "bicubic"] else None
if antialias:
img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners, antialias=antialias)
if interpolation == "bilinear":
img = torch.ops.torchvision._interpolate_bilinear2d_aa(img, [new_h, new_w], align_corners=False)
elif interpolation == "bicubic":
img = torch.ops.torchvision._interpolate_bicubic2d_aa(img, [new_h, new_w], align_corners=False)
else:
img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)
if interpolation == "bicubic" and out_dtype == torch.uint8: if interpolation == "bicubic" and out_dtype == torch.uint8:
img = img.clamp(min=0, max=255) img = img.clamp(min=0, max=255)