Unverified Commit 3c7321c0 authored by pc's avatar pc Committed by GitHub
Browse files

[Fix] Fix pybind exporting bug in fused_bias_leakyrelu and upfirdn2d (#1005)

* fix export bug in pybind

* fix type bug in fused_bias_leakyrelu backward
parent ab973df6
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "roi_pool_pytorch.h" #include "roi_pool_pytorch.h"
using namespace parrots; using namespace parrots;
#ifdef MMCV_WITH_CUDA
void roi_pool_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr, void roi_pool_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
const OperatorBase::in_list_t& ins, const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) { OperatorBase::out_list_t& outs) {
...@@ -62,3 +63,4 @@ PARROTS_EXTENSION_REGISTER(roi_pool_backward) ...@@ -62,3 +63,4 @@ PARROTS_EXTENSION_REGISTER(roi_pool_backward)
.output(1) .output(1)
.apply(roi_pool_backward_cuda_parrots) .apply(roi_pool_backward_cuda_parrots)
.done(); .done();
#endif
...@@ -214,9 +214,14 @@ void roi_align_rotated_backward(Tensor grad_output, Tensor rois, ...@@ -214,9 +214,14 @@ void roi_align_rotated_backward(Tensor grad_output, Tensor rois,
int sample_num, bool aligned, bool clockwise); int sample_num, bool aligned, bool clockwise);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
py::arg("down_y"), py::arg("pad_x0"), py::arg("pad_x1"),
py::arg("pad_y0"), py::arg("pad_y1"));
m.def("fused_bias_leakyrelu", &fused_bias_leakyrelu, m.def("fused_bias_leakyrelu", &fused_bias_leakyrelu,
"fused_bias_leakyrelu (CUDA)"); "fused_bias_leakyrelu (CUDA)", py::arg("input"), py::arg("bias"),
py::arg("empty"), py::arg("act"), py::arg("grad"), py::arg("alpha"),
py::arg("scale"));
m.def("get_compiler_version", &get_compiler_version, "get_compiler_version"); m.def("get_compiler_version", &get_compiler_version, "get_compiler_version");
m.def("get_compiling_cuda_version", &get_compiling_cuda_version, m.def("get_compiling_cuda_version", &get_compiling_cuda_version,
"get_compiling_cuda_version"); "get_compiling_cuda_version");
......
...@@ -52,7 +52,7 @@ class FusedBiasLeakyReLUFunctionBackward(Function): ...@@ -52,7 +52,7 @@ class FusedBiasLeakyReLUFunctionBackward(Function):
# which is similar with the first order deviation in implementation. # which is similar with the first order deviation in implementation.
gradgrad_out = ext_module.fused_bias_leakyrelu( gradgrad_out = ext_module.fused_bias_leakyrelu(
gradgrad_input, gradgrad_input,
gradgrad_bias, gradgrad_bias.to(out.dtype),
out, out,
act=3, act=3,
grad=1, grad=1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment