Unverified Commit a8ebd0b3 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

DeformConv2d: SymInt support + meta-implem + opchecks (#8063)

parent 668348ed
......@@ -1019,6 +1019,7 @@ class TestDeformConv:
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize("batch_sz", (0, 33))
@pytest.mark.opcheck_only_one()
def test_forward(self, device, contiguous, batch_sz, dtype=None):
dtype = dtype or self.dtype
x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
......@@ -1071,6 +1072,7 @@ class TestDeformConv:
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize("batch_sz", (0, 33))
@pytest.mark.opcheck_only_one()
def test_backward(self, device, contiguous, batch_sz):
x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(
device, contiguous, batch_sz, self.dtype
......@@ -1120,6 +1122,7 @@ class TestDeformConv:
@needs_cuda
@pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.opcheck_only_one()
def test_compare_cpu_cuda_grads(self, contiguous):
# Test from https://github.com/pytorch/vision/issues/2598
# Run on CUDA only
......@@ -1154,6 +1157,7 @@ class TestDeformConv:
@needs_cuda
@pytest.mark.parametrize("batch_sz", (0, 33))
@pytest.mark.parametrize("dtype", (torch.float, torch.half))
@pytest.mark.opcheck_only_one()
def test_autocast(self, batch_sz, dtype):
with torch.cuda.amp.autocast():
self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype)
......@@ -1163,6 +1167,15 @@ class TestDeformConv:
torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))
optests.generate_opcheck_tests(
testcase=TestDeformConv,
namespaces=["torchvision"],
failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
additional_decorators=[],
test_utils=OPTESTS,
)
class TestFrozenBNT:
def test_frozenbatchnorm2d_repr(self):
num_features = 32
......
......@@ -172,3 +172,54 @@ def meta_nms(dets, scores, iou_threshold):
ctx = torch._custom_ops.get_ctx()
num_to_keep = ctx.create_unbacked_symint()
return dets.new_empty(num_to_keep, dtype=torch.long)
@register_meta("deform_conv2d")
def meta_deform_conv2d(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dil_h,
dil_w,
n_weight_grps,
n_offset_grps,
use_mask,
):
out_height, out_width = offset.shape[-2:]
out_channels = weight.shape[0]
batch_size = input.shape[0]
return input.new_empty((batch_size, out_channels, out_height, out_width))
@register_meta("_deform_conv2d_backward")
def meta_deform_conv2d_backward(
grad,
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask,
):
grad_input = input.new_empty(input.shape)
grad_weight = weight.new_empty(weight.shape)
grad_offset = offset.new_empty(offset.shape)
grad_mask = mask.new_empty(mask.shape)
grad_bias = bias.new_empty(bias.shape)
return grad_input, grad_weight, grad_offset, grad_mask, grad_bias
......@@ -18,17 +18,17 @@ class DeformConv2dFunction
const torch::autograd::Variable& offset,
const torch::autograd::Variable& mask,
const torch::autograd::Variable& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask) {
at::AutoDispatchBelowADInplaceOrView g;
auto output = deform_conv2d(
auto output = deform_conv2d_symint(
input,
weight,
offset,
......@@ -70,17 +70,17 @@ class DeformConv2dFunction
auto mask = saved[3];
auto bias = saved[4];
auto stride_h = ctx->saved_data["stride_h"].toInt();
auto stride_w = ctx->saved_data["stride_w"].toInt();
auto pad_h = ctx->saved_data["pad_h"].toInt();
auto pad_w = ctx->saved_data["pad_w"].toInt();
auto dilation_h = ctx->saved_data["dilation_h"].toInt();
auto dilation_w = ctx->saved_data["dilation_w"].toInt();
auto groups = ctx->saved_data["groups"].toInt();
auto offset_groups = ctx->saved_data["offset_groups"].toInt();
auto stride_h = ctx->saved_data["stride_h"].toSymInt();
auto stride_w = ctx->saved_data["stride_w"].toSymInt();
auto pad_h = ctx->saved_data["pad_h"].toSymInt();
auto pad_w = ctx->saved_data["pad_w"].toSymInt();
auto dilation_h = ctx->saved_data["dilation_h"].toSymInt();
auto dilation_w = ctx->saved_data["dilation_w"].toSymInt();
auto groups = ctx->saved_data["groups"].toSymInt();
auto offset_groups = ctx->saved_data["offset_groups"].toSymInt();
auto use_mask = ctx->saved_data["use_mask"].toBool();
auto grads = detail::_deform_conv2d_backward(
auto grads = detail::_deform_conv2d_backward_symint(
grad_output[0],
input,
weight,
......@@ -133,17 +133,17 @@ class DeformConv2dBackwardFunction
const torch::autograd::Variable& offset,
const torch::autograd::Variable& mask,
const torch::autograd::Variable& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask) {
at::AutoDispatchBelowADInplaceOrView g;
auto result = detail::_deform_conv2d_backward(
auto result = detail::_deform_conv2d_backward_symint(
grad,
input,
weight,
......@@ -188,14 +188,14 @@ at::Tensor deform_conv2d_autograd(
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask) {
return DeformConv2dFunction::apply(
input,
......@@ -222,14 +222,14 @@ deform_conv2d_backward_autograd(
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask) {
auto result = DeformConv2dBackwardFunction::apply(
grad,
......
......@@ -43,6 +43,42 @@ at::Tensor deform_conv2d(
use_mask);
}
at::Tensor deform_conv2d_symint(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask) {
C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d");
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::deform_conv2d", "")
.typed<decltype(deform_conv2d_symint)>();
return op.call(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask);
}
namespace detail {
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
......@@ -84,13 +120,52 @@ _deform_conv2d_backward(
use_mask);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_deform_conv2d_backward_symint(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
.typed<decltype(_deform_conv2d_backward_symint)>();
return op.call(
grad,
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask);
}
} // namespace detail
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor"));
"torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"));
"torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"));
}
} // namespace ops
......
......@@ -22,6 +22,22 @@ VISION_API at::Tensor deform_conv2d(
int64_t offset_groups,
bool use_mask);
VISION_API at::Tensor deform_conv2d_symint(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask);
namespace detail {
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
......@@ -42,6 +58,24 @@ _deform_conv2d_backward(
int64_t offset_groups,
bool use_mask);
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_deform_conv2d_backward_symint(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask);
} // namespace detail
} // namespace ops
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment