Unverified Commit 8f2e5c90 authored by Edward Z. Yang's avatar Edward Z. Yang Committed by GitHub
Browse files

SymIntify roi_align (#7448)


Signed-off-by: default avatarEdward Z. Yang <ezyang@meta.com>
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
parent db3ead16
...@@ -3,7 +3,7 @@ import warnings ...@@ -3,7 +3,7 @@ import warnings
from modulefinder import Module from modulefinder import Module
import torch import torch
from torchvision import datasets, io, models, ops, transforms, utils from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils
from .extension import _HAS_OPS from .extension import _HAS_OPS
......
import torch
import torch.library
# Ensure that torch.ops.torchvision is visible
import torchvision.extension # noqa: F401
from torch._prims_common import check
_meta_lib = torch.library.Library("torchvision", "IMPL", "Meta")
vision = torch.ops.torchvision
def register_meta(op):
def wrapper(fn):
_meta_lib.impl(op, fn)
return fn
return wrapper
@register_meta(vision.roi_align.default)
def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
check(
input.dtype == rois.dtype,
lambda: (
"Expected tensor for input to have the same type as tensor for rois; "
f"but type {input.dtype} does not equal {rois.dtype}"
),
)
num_rois = rois.size(0)
_, channels, height, width = input.size()
return input.new_empty((num_rois, channels, pooled_height, pooled_width))
@register_meta(vision._roi_align_backward.default)
def meta_roi_align_backward(
grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
):
check(
grad.dtype == rois.dtype,
lambda: (
"Expected tensor for grad to have the same type as tensor for rois; "
f"but type {grad.dtype} does not equal {rois.dtype}"
),
)
return grad.new_empty((batch_size, channels, height, width))
...@@ -15,8 +15,8 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> { ...@@ -15,8 +15,8 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
const torch::autograd::Variable& input, const torch::autograd::Variable& input,
const torch::autograd::Variable& rois, const torch::autograd::Variable& rois,
double spatial_scale, double spatial_scale,
int64_t pooled_height, c10::SymInt pooled_height,
int64_t pooled_width, c10::SymInt pooled_width,
int64_t sampling_ratio, int64_t sampling_ratio,
bool aligned) { bool aligned) {
ctx->saved_data["spatial_scale"] = spatial_scale; ctx->saved_data["spatial_scale"] = spatial_scale;
...@@ -24,10 +24,10 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> { ...@@ -24,10 +24,10 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
ctx->saved_data["pooled_width"] = pooled_width; ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio; ctx->saved_data["sampling_ratio"] = sampling_ratio;
ctx->saved_data["aligned"] = aligned; ctx->saved_data["aligned"] = aligned;
ctx->saved_data["input_shape"] = input.sizes(); ctx->saved_data["input_shape"] = input.sym_sizes();
ctx->save_for_backward({rois}); ctx->save_for_backward({rois});
at::AutoDispatchBelowADInplaceOrView g; at::AutoDispatchBelowADInplaceOrView g;
auto result = roi_align( auto result = roi_align_symint(
input, input,
rois, rois,
spatial_scale, spatial_scale,
...@@ -44,17 +44,17 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> { ...@@ -44,17 +44,17 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
// Use data saved in forward // Use data saved in forward
auto saved = ctx->get_saved_variables(); auto saved = ctx->get_saved_variables();
auto rois = saved[0]; auto rois = saved[0];
auto input_shape = ctx->saved_data["input_shape"].toIntList(); auto input_shape = ctx->saved_data["input_shape"].toList();
auto grad_in = detail::_roi_align_backward( auto grad_in = detail::_roi_align_backward_symint(
grad_output[0], grad_output[0],
rois, rois,
ctx->saved_data["spatial_scale"].toDouble(), ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(), ctx->saved_data["pooled_height"].toSymInt(),
ctx->saved_data["pooled_width"].toInt(), ctx->saved_data["pooled_width"].toSymInt(),
input_shape[0], input_shape[0].get().toSymInt(),
input_shape[1], input_shape[1].get().toSymInt(),
input_shape[2], input_shape[2].get().toSymInt(),
input_shape[3], input_shape[3].get().toSymInt(),
ctx->saved_data["sampling_ratio"].toInt(), ctx->saved_data["sampling_ratio"].toInt(),
ctx->saved_data["aligned"].toBool()); ctx->saved_data["aligned"].toBool());
return { return {
...@@ -77,16 +77,16 @@ class ROIAlignBackwardFunction ...@@ -77,16 +77,16 @@ class ROIAlignBackwardFunction
const torch::autograd::Variable& grad, const torch::autograd::Variable& grad,
const torch::autograd::Variable& rois, const torch::autograd::Variable& rois,
double spatial_scale, double spatial_scale,
int64_t pooled_height, c10::SymInt pooled_height,
int64_t pooled_width, c10::SymInt pooled_width,
int64_t batch_size, c10::SymInt batch_size,
int64_t channels, c10::SymInt channels,
int64_t height, c10::SymInt height,
int64_t width, c10::SymInt width,
int64_t sampling_ratio, int64_t sampling_ratio,
bool aligned) { bool aligned) {
at::AutoDispatchBelowADInplaceOrView g; at::AutoDispatchBelowADInplaceOrView g;
auto result = detail::_roi_align_backward( auto result = detail::_roi_align_backward_symint(
grad, grad,
rois, rois,
spatial_scale, spatial_scale,
...@@ -112,8 +112,8 @@ at::Tensor roi_align_autograd( ...@@ -112,8 +112,8 @@ at::Tensor roi_align_autograd(
const at::Tensor& input, const at::Tensor& input,
const at::Tensor& rois, const at::Tensor& rois,
double spatial_scale, double spatial_scale,
int64_t pooled_height, c10::SymInt pooled_height,
int64_t pooled_width, c10::SymInt pooled_width,
int64_t sampling_ratio, int64_t sampling_ratio,
bool aligned) { bool aligned) {
return ROIAlignFunction::apply( return ROIAlignFunction::apply(
...@@ -130,12 +130,12 @@ at::Tensor roi_align_backward_autograd( ...@@ -130,12 +130,12 @@ at::Tensor roi_align_backward_autograd(
const at::Tensor& grad, const at::Tensor& grad,
const at::Tensor& rois, const at::Tensor& rois,
double spatial_scale, double spatial_scale,
int64_t pooled_height, c10::SymInt pooled_height,
int64_t pooled_width, c10::SymInt pooled_width,
int64_t batch_size, c10::SymInt batch_size,
int64_t channels, c10::SymInt channels,
int64_t height, c10::SymInt height,
int64_t width, c10::SymInt width,
int64_t sampling_ratio, int64_t sampling_ratio,
bool aligned) { bool aligned) {
return ROIAlignBackwardFunction::apply( return ROIAlignBackwardFunction::apply(
......
...@@ -32,6 +32,31 @@ at::Tensor roi_align( ...@@ -32,6 +32,31 @@ at::Tensor roi_align(
aligned); aligned);
} }
at::Tensor roi_align_symint(
const at::Tensor& input, // Input feature map.
const at::Tensor& rois, // List of ROIs to pool over.
double spatial_scale, // The scale of the image features. ROIs will be
// scaled to this.
c10::SymInt pooled_height, // The height of the pooled feature map.
c10::SymInt pooled_width, // The width of the pooled feature
int64_t sampling_ratio, // The number of points to sample in each bin
bool aligned) // The flag for pixel shift
// along each axis.
{
C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align");
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::roi_align", "")
.typed<decltype(roi_align_symint)>();
return op.call(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
aligned);
}
namespace detail { namespace detail {
at::Tensor _roi_align_backward( at::Tensor _roi_align_backward(
...@@ -64,13 +89,43 @@ at::Tensor _roi_align_backward( ...@@ -64,13 +89,43 @@ at::Tensor _roi_align_backward(
aligned); aligned);
} }
at::Tensor _roi_align_backward_symint(
const at::Tensor& grad,
const at::Tensor& rois,
double spatial_scale,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
c10::SymInt batch_size,
c10::SymInt channels,
c10::SymInt height,
c10::SymInt width,
int64_t sampling_ratio,
bool aligned) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_roi_align_backward", "")
.typed<decltype(_roi_align_backward_symint)>();
return op.call(
grad,
rois,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width,
sampling_ratio,
aligned);
}
} // namespace detail } // namespace detail
TORCH_LIBRARY_FRAGMENT(torchvision, m) { TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(TORCH_SELECTIVE_SCHEMA( m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor")); "torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, bool aligned) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA( m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor")); "torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width, int sampling_ratio, bool aligned) -> Tensor"));
} }
} // namespace ops } // namespace ops
......
...@@ -15,6 +15,15 @@ VISION_API at::Tensor roi_align( ...@@ -15,6 +15,15 @@ VISION_API at::Tensor roi_align(
int64_t sampling_ratio, int64_t sampling_ratio,
bool aligned); bool aligned);
VISION_API at::Tensor roi_align_symint(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio,
bool aligned);
namespace detail { namespace detail {
at::Tensor _roi_align_backward( at::Tensor _roi_align_backward(
...@@ -30,6 +39,19 @@ at::Tensor _roi_align_backward( ...@@ -30,6 +39,19 @@ at::Tensor _roi_align_backward(
int64_t sampling_ratio, int64_t sampling_ratio,
bool aligned); bool aligned);
at::Tensor _roi_align_backward_symint(
const at::Tensor& grad,
const at::Tensor& rois,
double spatial_scale,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
c10::SymInt batch_size,
c10::SymInt channels,
c10::SymInt height,
c10::SymInt width,
int64_t sampling_ratio,
bool aligned);
} // namespace detail } // namespace detail
} // namespace ops } // namespace ops
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment