Unverified Commit a8ebd0b3 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

DeformConv2d: SymInt support + meta-implem + opchecks (#8063)

parent 668348ed
...@@ -1019,6 +1019,7 @@ class TestDeformConv: ...@@ -1019,6 +1019,7 @@ class TestDeformConv:
@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize("batch_sz", (0, 33)) @pytest.mark.parametrize("batch_sz", (0, 33))
@pytest.mark.opcheck_only_one()
def test_forward(self, device, contiguous, batch_sz, dtype=None): def test_forward(self, device, contiguous, batch_sz, dtype=None):
dtype = dtype or self.dtype dtype = dtype or self.dtype
x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype) x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
...@@ -1071,6 +1072,7 @@ class TestDeformConv: ...@@ -1071,6 +1072,7 @@ class TestDeformConv:
@pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize("batch_sz", (0, 33)) @pytest.mark.parametrize("batch_sz", (0, 33))
@pytest.mark.opcheck_only_one()
def test_backward(self, device, contiguous, batch_sz): def test_backward(self, device, contiguous, batch_sz):
x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args( x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(
device, contiguous, batch_sz, self.dtype device, contiguous, batch_sz, self.dtype
...@@ -1120,6 +1122,7 @@ class TestDeformConv: ...@@ -1120,6 +1122,7 @@ class TestDeformConv:
@needs_cuda @needs_cuda
@pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.opcheck_only_one()
def test_compare_cpu_cuda_grads(self, contiguous): def test_compare_cpu_cuda_grads(self, contiguous):
# Test from https://github.com/pytorch/vision/issues/2598 # Test from https://github.com/pytorch/vision/issues/2598
# Run on CUDA only # Run on CUDA only
...@@ -1154,6 +1157,7 @@ class TestDeformConv: ...@@ -1154,6 +1157,7 @@ class TestDeformConv:
@needs_cuda @needs_cuda
@pytest.mark.parametrize("batch_sz", (0, 33)) @pytest.mark.parametrize("batch_sz", (0, 33))
@pytest.mark.parametrize("dtype", (torch.float, torch.half)) @pytest.mark.parametrize("dtype", (torch.float, torch.half))
@pytest.mark.opcheck_only_one()
def test_autocast(self, batch_sz, dtype): def test_autocast(self, batch_sz, dtype):
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype) self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype)
...@@ -1163,6 +1167,15 @@ class TestDeformConv: ...@@ -1163,6 +1167,15 @@ class TestDeformConv:
torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3)) torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))
optests.generate_opcheck_tests(
testcase=TestDeformConv,
namespaces=["torchvision"],
failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
additional_decorators=[],
test_utils=OPTESTS,
)
class TestFrozenBNT: class TestFrozenBNT:
def test_frozenbatchnorm2d_repr(self): def test_frozenbatchnorm2d_repr(self):
num_features = 32 num_features = 32
......
...@@ -172,3 +172,54 @@ def meta_nms(dets, scores, iou_threshold): ...@@ -172,3 +172,54 @@ def meta_nms(dets, scores, iou_threshold):
ctx = torch._custom_ops.get_ctx() ctx = torch._custom_ops.get_ctx()
num_to_keep = ctx.create_unbacked_symint() num_to_keep = ctx.create_unbacked_symint()
return dets.new_empty(num_to_keep, dtype=torch.long) return dets.new_empty(num_to_keep, dtype=torch.long)
@register_meta("deform_conv2d")
def meta_deform_conv2d(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dil_h,
dil_w,
n_weight_grps,
n_offset_grps,
use_mask,
):
out_height, out_width = offset.shape[-2:]
out_channels = weight.shape[0]
batch_size = input.shape[0]
return input.new_empty((batch_size, out_channels, out_height, out_width))
@register_meta("_deform_conv2d_backward")
def meta_deform_conv2d_backward(
grad,
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask,
):
grad_input = input.new_empty(input.shape)
grad_weight = weight.new_empty(weight.shape)
grad_offset = offset.new_empty(offset.shape)
grad_mask = mask.new_empty(mask.shape)
grad_bias = bias.new_empty(bias.shape)
return grad_input, grad_weight, grad_offset, grad_mask, grad_bias
...@@ -18,17 +18,17 @@ class DeformConv2dFunction ...@@ -18,17 +18,17 @@ class DeformConv2dFunction
const torch::autograd::Variable& offset, const torch::autograd::Variable& offset,
const torch::autograd::Variable& mask, const torch::autograd::Variable& mask,
const torch::autograd::Variable& bias, const torch::autograd::Variable& bias,
int64_t stride_h, c10::SymInt stride_h,
int64_t stride_w, c10::SymInt stride_w,
int64_t pad_h, c10::SymInt pad_h,
int64_t pad_w, c10::SymInt pad_w,
int64_t dilation_h, c10::SymInt dilation_h,
int64_t dilation_w, c10::SymInt dilation_w,
int64_t groups, c10::SymInt groups,
int64_t offset_groups, c10::SymInt offset_groups,
bool use_mask) { bool use_mask) {
at::AutoDispatchBelowADInplaceOrView g; at::AutoDispatchBelowADInplaceOrView g;
auto output = deform_conv2d( auto output = deform_conv2d_symint(
input, input,
weight, weight,
offset, offset,
...@@ -70,17 +70,17 @@ class DeformConv2dFunction ...@@ -70,17 +70,17 @@ class DeformConv2dFunction
auto mask = saved[3]; auto mask = saved[3];
auto bias = saved[4]; auto bias = saved[4];
auto stride_h = ctx->saved_data["stride_h"].toInt(); auto stride_h = ctx->saved_data["stride_h"].toSymInt();
auto stride_w = ctx->saved_data["stride_w"].toInt(); auto stride_w = ctx->saved_data["stride_w"].toSymInt();
auto pad_h = ctx->saved_data["pad_h"].toInt(); auto pad_h = ctx->saved_data["pad_h"].toSymInt();
auto pad_w = ctx->saved_data["pad_w"].toInt(); auto pad_w = ctx->saved_data["pad_w"].toSymInt();
auto dilation_h = ctx->saved_data["dilation_h"].toInt(); auto dilation_h = ctx->saved_data["dilation_h"].toSymInt();
auto dilation_w = ctx->saved_data["dilation_w"].toInt(); auto dilation_w = ctx->saved_data["dilation_w"].toSymInt();
auto groups = ctx->saved_data["groups"].toInt(); auto groups = ctx->saved_data["groups"].toSymInt();
auto offset_groups = ctx->saved_data["offset_groups"].toInt(); auto offset_groups = ctx->saved_data["offset_groups"].toSymInt();
auto use_mask = ctx->saved_data["use_mask"].toBool(); auto use_mask = ctx->saved_data["use_mask"].toBool();
auto grads = detail::_deform_conv2d_backward( auto grads = detail::_deform_conv2d_backward_symint(
grad_output[0], grad_output[0],
input, input,
weight, weight,
...@@ -133,17 +133,17 @@ class DeformConv2dBackwardFunction ...@@ -133,17 +133,17 @@ class DeformConv2dBackwardFunction
const torch::autograd::Variable& offset, const torch::autograd::Variable& offset,
const torch::autograd::Variable& mask, const torch::autograd::Variable& mask,
const torch::autograd::Variable& bias, const torch::autograd::Variable& bias,
int64_t stride_h, c10::SymInt stride_h,
int64_t stride_w, c10::SymInt stride_w,
int64_t pad_h, c10::SymInt pad_h,
int64_t pad_w, c10::SymInt pad_w,
int64_t dilation_h, c10::SymInt dilation_h,
int64_t dilation_w, c10::SymInt dilation_w,
int64_t groups, c10::SymInt groups,
int64_t offset_groups, c10::SymInt offset_groups,
bool use_mask) { bool use_mask) {
at::AutoDispatchBelowADInplaceOrView g; at::AutoDispatchBelowADInplaceOrView g;
auto result = detail::_deform_conv2d_backward( auto result = detail::_deform_conv2d_backward_symint(
grad, grad,
input, input,
weight, weight,
...@@ -188,14 +188,14 @@ at::Tensor deform_conv2d_autograd( ...@@ -188,14 +188,14 @@ at::Tensor deform_conv2d_autograd(
const at::Tensor& offset, const at::Tensor& offset,
const at::Tensor& mask, const at::Tensor& mask,
const at::Tensor& bias, const at::Tensor& bias,
int64_t stride_h, c10::SymInt stride_h,
int64_t stride_w, c10::SymInt stride_w,
int64_t pad_h, c10::SymInt pad_h,
int64_t pad_w, c10::SymInt pad_w,
int64_t dilation_h, c10::SymInt dilation_h,
int64_t dilation_w, c10::SymInt dilation_w,
int64_t groups, c10::SymInt groups,
int64_t offset_groups, c10::SymInt offset_groups,
bool use_mask) { bool use_mask) {
return DeformConv2dFunction::apply( return DeformConv2dFunction::apply(
input, input,
...@@ -222,14 +222,14 @@ deform_conv2d_backward_autograd( ...@@ -222,14 +222,14 @@ deform_conv2d_backward_autograd(
const at::Tensor& offset, const at::Tensor& offset,
const at::Tensor& mask, const at::Tensor& mask,
const at::Tensor& bias, const at::Tensor& bias,
int64_t stride_h, c10::SymInt stride_h,
int64_t stride_w, c10::SymInt stride_w,
int64_t pad_h, c10::SymInt pad_h,
int64_t pad_w, c10::SymInt pad_w,
int64_t dilation_h, c10::SymInt dilation_h,
int64_t dilation_w, c10::SymInt dilation_w,
int64_t groups, c10::SymInt groups,
int64_t offset_groups, c10::SymInt offset_groups,
bool use_mask) { bool use_mask) {
auto result = DeformConv2dBackwardFunction::apply( auto result = DeformConv2dBackwardFunction::apply(
grad, grad,
......
...@@ -43,6 +43,42 @@ at::Tensor deform_conv2d( ...@@ -43,6 +43,42 @@ at::Tensor deform_conv2d(
use_mask); use_mask);
} }
at::Tensor deform_conv2d_symint(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask) {
C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d");
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::deform_conv2d", "")
.typed<decltype(deform_conv2d_symint)>();
return op.call(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask);
}
namespace detail { namespace detail {
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
...@@ -84,13 +120,52 @@ _deform_conv2d_backward( ...@@ -84,13 +120,52 @@ _deform_conv2d_backward(
use_mask); use_mask);
} }
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_deform_conv2d_backward_symint(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
.typed<decltype(_deform_conv2d_backward_symint)>();
return op.call(
grad,
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask);
}
} // namespace detail } // namespace detail
TORCH_LIBRARY_FRAGMENT(torchvision, m) { TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(TORCH_SELECTIVE_SCHEMA( m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor")); "torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA( m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)")); "torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"));
} }
} // namespace ops } // namespace ops
......
...@@ -22,6 +22,22 @@ VISION_API at::Tensor deform_conv2d( ...@@ -22,6 +22,22 @@ VISION_API at::Tensor deform_conv2d(
int64_t offset_groups, int64_t offset_groups,
bool use_mask); bool use_mask);
VISION_API at::Tensor deform_conv2d_symint(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask);
namespace detail { namespace detail {
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
...@@ -42,6 +58,24 @@ _deform_conv2d_backward( ...@@ -42,6 +58,24 @@ _deform_conv2d_backward(
int64_t offset_groups, int64_t offset_groups,
bool use_mask); bool use_mask);
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_deform_conv2d_backward_symint(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask);
} // namespace detail } // namespace detail
} // namespace ops } // namespace ops
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment