Unverified Commit 2708fac6 authored by zhouchenyang's avatar zhouchenyang Committed by GitHub
Browse files

[Feature] Support RoiPool with Cambricon MLU backend (#2073)

* [Feature] Support RoiPool with cambricon MLU backend

* [Docs] Update ops.md
parent d71d067d
...@@ -41,7 +41,7 @@ We implement common ops used in detection, segmentation, etc. ...@@ -41,7 +41,7 @@ We implement common ops used in detection, segmentation, etc.
| PSAMask | √ | √ | √ | | PSAMask | √ | √ | √ |
| RotatedFeatureAlign | √ | √ | | | RotatedFeatureAlign | √ | √ | |
| RoIPointPool3d | | √ | | | RoIPointPool3d | | √ | |
| RoIPool | | √ | | | RoIPool | | √ | √ |
| RoIAlignRotated | √ | √ | √ | | RoIAlignRotated | √ | √ | √ |
| RiRoIAlignRotated | | √ | | | RiRoIAlignRotated | | √ | |
| RoIAlign | √ | √ | √ | | RoIAlign | √ | √ | √ |
......
...@@ -41,7 +41,7 @@ MMCV 提供了检测、分割等任务中常用的算子 ...@@ -41,7 +41,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| PSAMask | √ | √ | √ | | PSAMask | √ | √ | √ |
| RotatedFeatureAlign | √ | √ | | | RotatedFeatureAlign | √ | √ | |
| RoIPointPool3d | | √ | | | RoIPointPool3d | | √ | |
| RoIPool | | √ | | | RoIPool | | √ | √ |
| RoIAlignRotated | √ | √ | √ | | RoIAlignRotated | √ | √ | √ |
| RiRoIAlignRotated | | √ | | | RiRoIAlignRotated | | √ | |
| RoIAlign | √ | √ | √ | | RoIAlign | √ | √ | √ |
......
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/ *************************************************************************/
#ifndef UTILS_H_ #ifndef COMMON_MLU_HELPER_HPP_
#define UTILS_H_ #define COMMON_MLU_HELPER_HPP_
#define NFU_ALIGN_SIZE 128 // Byte #define NFU_ALIGN_SIZE 128 // Byte
#define REM_FOR_STACK (128 * 1024) // 128KB reserved for cncc #define REM_FOR_STACK (128 * 1024) // 128KB reserved for cncc
...@@ -35,4 +35,156 @@ ...@@ -35,4 +35,156 @@
#define CEIL_ALIGN(x, y) (((x) + (y)-1) / (y) * (y)) #define CEIL_ALIGN(x, y) (((x) + (y)-1) / (y) * (y))
#endif // UTILS_H_ /*!
* @brief Converts int32 to float32 data type.
*
 * @param[out] dst
 *     Pointer to NRAM that stores the converted float32 type data.
 * @param[in,out] dst_addition
 *     Pointer to NRAM as the workspace of dst, which has the same size as dst.
 *     It allows empty pointer on MLU300 series.
 * @param[in] src
 *     Pointer to NRAM that stores the int32 type data to be converted.
* @param[in,out] src_addition
* Pointer to NRAM as the workspace of src, which has a size of 128 Bytes.
* It allows empty pointer on MLU300 series.
* @param[in] src_count
* The count of elements in src.
*/
// Converts src_count int32 elements in `src` to float32 in `dst`.
// On MLU300+ this is one native instruction; on earlier hardware it is
// emulated with bitwise ops, using `dst_addition` (same size as dst) and
// `src_addition` (128 bytes) as scratch space.
__mlu_func__ void convertInt2Float(float *dst, float *dst_addition, int *src,
                                   float *src_addition, const int src_count) {
#if __BANG_ARCH__ >= 300
  // MLU300 and later provide a native int32 -> float32 conversion.
  __bang_int2float((float *)dst, (int32_t *)src, src_count, 0);
#else
  // 8388608.0 == 2^23: used below to shift the integer magnitude out of the
  // float mantissa field.
  const float move_23bit = 8388608.0;
  // get sign bit of every element
  // 0x80000000 = 1,000000000,0000000000000000000000000000
  __nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
            0x80000000);
  __bang_cycle_band((char *)dst_addition, (char *)src, (char *)src_addition,
                    src_count * sizeof(float), NFU_ALIGN_SIZE);
  // get 1 or 0 from the sign bit (flag marking negative elements)
  __nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
            0x00000001);
  __bang_cycle_bor((char *)dst_addition, (char *)dst_addition,
                   (char *)src_addition, src_count * sizeof(float),
                   NFU_ALIGN_SIZE);
  __nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
            0x80000001);
  __bang_cycle_eq(dst_addition, dst_addition, src_addition, src_count,
                  NFU_ALIGN_SIZE / sizeof(float));
  // xor negative elements with all-ones (bitwise complement); positive
  // elements stay unchanged
  __nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
            0xffffffff);
  __bang_cycle_mul(dst, dst_addition, src_addition, src_count,
                   NFU_ALIGN_SIZE / sizeof(float));
  __bang_bxor((char *)dst, (char *)src, (char *)dst, src_count * sizeof(float));
  // convert int32 to float32: keep the low 23 bits as the mantissa, OR in the
  // exponent of 2^23 (0x4b000000), then subtract 2^23
  __nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float), 0x7fffff);
  __bang_cycle_band((char *)dst, (char *)dst, (char *)src_addition,
                    src_count * sizeof(float), NFU_ALIGN_SIZE);
  __nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
            0x4b000000);
  __bang_cycle_bor((char *)dst, (char *)dst, (char *)src_addition,
                   src_count * sizeof(float), NFU_ALIGN_SIZE);
  __bang_sub_const(dst, dst, move_23bit, src_count);
  // add one (completes the magnitude for the complemented negative elements)
  __bang_add(dst, dst, dst_addition, src_count);
  // set sign for float32 results
  __nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
            0xffffffff);
  __bang_cycle_mul(dst_addition, dst_addition, src_addition, src_count,
                   NFU_ALIGN_SIZE / sizeof(float));
  __nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
            0x00000001);
  __bang_cycle_add(dst_addition, dst_addition, src_addition, src_count,
                   NFU_ALIGN_SIZE / sizeof(float));
  __nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
            0x80000000);
  __bang_cycle_band((char *)dst_addition, (char *)dst_addition,
                    (char *)src_addition, src_count * 4, 128);
  __bang_bor((char *)dst, (char *)dst, (char *)dst_addition, src_count * 4);
#endif  // __BANG_ARCH__ >= 300
}
/*!
 * @brief Converts float32 to int32 data type with to_zero round mode.
 *
 * @param[out] dst
 *     Pointer to NRAM that stores the converted int32 type data.
 * @param[in,out] dst_addition
 *     Pointer to NRAM as the workspace of dst, which has the same size as dst.
 *     It allows empty pointer on MLU300 series.
 * @param[in] src
 *     Pointer to NRAM that stores the float32 type data to be converted.
 * @param[in,out] src_addition
 *     Pointer to NRAM as the workspace of src, which has a size of 128 Bytes.
 *     It allows empty pointer on MLU300 series.
 * @param[in] src_count
 *     The count of elements in src.
 */
__mlu_func__ void convertFloat2Int(int *dst, float *dst_addition, float *src,
                                   float *src_addition, const int src_count) {
#if __BANG_ARCH__ >= 300
  // MLU300 and later provide a native truncating float32 -> int32 conversion.
  __bang_float2int_tz((int32_t *)dst, (float *)src, src_count, 0);
#else
  // Pre-MLU300: emulate truncation with bitwise/arithmetic ops.
  // sign ===> src_addition
  // dst=-1.0 : when src[i] is a negative number
  // dst=+1.0 : when src[i] is a positive number
  const int floatDchar = sizeof(float) / sizeof(char);
  __bang_active_sign((float *)dst, src, src_count);
  // dst_addition = abs(src)
  __bang_mul(dst_addition, src, (float *)dst, src_count);
  // if dst_addition < 1.0 , then src_addition + 1, to fix add error.
  __nramset((float *)src_addition, NFU_ALIGN_SIZE / sizeof(float), 1.0f);
  __bang_cycle_lt(dst_addition, dst_addition, (float *)src_addition, src_count,
                  NFU_ALIGN_SIZE / sizeof(float));
  __bang_add_tz((float *)dst, (float *)dst, (float *)dst_addition, src_count);
  __nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
            0xbf800000);
  // set negative flag: -1.0 = 0xbf800000
  __bang_cycle_eq(
      (float *)dst, (float *)dst, (float *)src_addition, src_count,
      NFU_ALIGN_SIZE / sizeof(float));  // to mark all src in [x<-1.0]
  __bang_active_abs(dst_addition, src, src_count);
  __nramset((float *)src_addition, NFU_ALIGN_SIZE / sizeof(float), 8388608.0f);
  // mask shift move 23
  __bang_cycle_add_tz(
      dst_addition, dst_addition, src_addition, src_count,
      NFU_ALIGN_SIZE / sizeof(float));  // right shift move 23bit
  // two's complement for negative
  // dst=1.0 , when src <-1.0
  // dst=0.0 , when src >=-1.0
  __bang_sub(dst_addition, dst_addition, (float *)dst, src_count);
  // to fix max value
  // 0 1001 0110 111 1111 1111 1111 1111 1111 <=> 0x4b7fffff <=> 16777215.0,
  // means max value.
  __bang_mul_const((float *)dst, (float *)dst, 16777215.0, src_count);
  __bang_bxor((char *)dst_addition, (char *)dst_addition, (char *)dst,
              src_count * floatDchar);
  // get low 23 bits
  __nramset((unsigned *)src_addition, NFU_ALIGN_SIZE / sizeof(float),
            (unsigned)0x007fffff);
  // mask: low 23 bits are 1
  __bang_cycle_band((char *)dst_addition, (char *)dst_addition,
                    (char *)src_addition, src_count * floatDchar,
                    NFU_ALIGN_SIZE / sizeof(char));
  // set 9 high bits ===> dst
  // -2.0 <=> 0xc0000000 <=> 1100 0000 0000 0000 0000 0000 0000 0000
  //  1.0 <=> 0x3f800000 <=> 0011 1111 1000 0000 0000 0000 0000 0000
  __nramset(src_addition, NFU_ALIGN_SIZE / sizeof(float), 0x3f800000);
  __bang_cycle_and((float *)dst, (float *)dst, src_addition, src_count,
                   NFU_ALIGN_SIZE / sizeof(float));
  // src or dst_addition
  __bang_bor((char *)dst_addition, (char *)dst, (char *)dst_addition,
             src_count * floatDchar);
  __bang_mul_const((float *)dst, (float *)dst, -2.0, src_count);
  __bang_bor((char *)dst, (char *)dst, (char *)dst_addition,
             src_count * floatDchar);
#endif  // __BANG_ARCH__ >= 300
}
#endif // COMMON_MLU_HELPER_HPP_
This diff is collapsed.
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
// Forward CNRT kernel entry point (defined in the BANG kernel source, not
// visible here). Launches `k_dim` tasks of `k_type` on `queue`; `data_type`
// selects the Float/Half path. Writes the pooled features into `output_data`
// and the per-bin max-element indices into `argmax`.
void KernelRoiPoolForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                          cnrtQueue_t queue, cnrtDataType_t data_type,
                          const void *input_data, const void *input_rois,
                          const int batch, const int channels, const int height,
                          const int width, const int pooled_height,
                          const int pooled_width, const int rois_num,
                          const float spatial_scale, void *output_data,
                          int *argmax);

// Backward CNRT kernel entry point (defined in the BANG kernel source).
// Scatters `grad_output_ptr` back into `grad_input_ptr` at the positions
// recorded in `argmax_ptr` by the forward pass.
void KernelRoiPoolBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                           cnrtQueue_t queue, cnrtDataType_t k_dtype,
                           const void *grad_output_ptr, const void *rois_ptr,
                           const int *argmax_ptr, void *grad_input_ptr,
                           const int box_num, const int pooled_height,
                           const int pooled_width, const int channels,
                           const int batch, const int height, const int width,
                           const float spatial_scale);
// Launch-dimension policy for the forward kernel: one UNION1 job whose x
// dimension is the cores-per-cluster count, using just enough clusters
// (capped at the device's cluster count) to cover `bin_num` output bins.
static void policyFuncForward(const int bin_num, cnrtDim3_t *k_dim,
                              cnrtFunctionType_t *k_type) {
  const auto cores_per_cluster =
      torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  const auto clusters_on_device =
      torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  // Ceiling division: clusters needed so every bin is assigned a core.
  unsigned int clusters_needed =
      bin_num / cores_per_cluster + (bin_num % cores_per_cluster > 0);
  if (clusters_needed > clusters_on_device) {
    clusters_needed = clusters_on_device;
  }
  *k_type = CNRT_FUNC_TYPE_UNION1;
  k_dim->x = cores_per_cluster;
  k_dim->y = clusters_needed;
  k_dim->z = 1;
}
// Validates inputs, converts them to channels-last contiguous layout,
// computes the launch configuration and enqueues the RoiPool forward kernel
// on the current MLU queue. Pooled features are written back into `output`
// and the per-bin max-element indices into `argmax` (consumed by backward).
void ROIPoolForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output,
                                     Tensor argmax, int pooled_height,
                                     int pooled_width, float spatial_scale) {
  // Check dtype.
  TORCH_CHECK(
      input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
      "input type should be Float or Half, got ", input.scalar_type());
  TORCH_CHECK(input.scalar_type() == rois.scalar_type(),
              "rois should have the same type as input");
  // Check dtype relationship.
  TORCH_CHECK(
      argmax.scalar_type() == at::kLong || argmax.scalar_type() == at::kInt,
      "argmax type should be Int or Long, got ", argmax.scalar_type());
  // Check shape: input NCHW, rois 2-D, argmax same rank as output.
  TORCH_CHECK(input.dim() == 4, "input should be 4d tensor, got ", input.dim(),
              "D");
  TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(),
              "D");
  TORCH_CHECK(argmax.dim() == 4, "argmax should be 4d tensor, got ",
              argmax.dim(), "D");
  TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1,
              "spatial_scale should be within (0, 1], got ", spatial_scale);
  // compute kernel params
  auto batch = input.size(0);
  auto height = input.size(2);
  auto width = input.size(3);
  auto channels = input.size(1);
  // NOTE(review): roi count is taken from output.size(0), so output is
  // expected to be pre-shaped by the caller — confirm against callers.
  auto rois_num = output.size(0);
  // Empty output/argmax: allocate zeros and return without launching.
  if (output.numel() == 0) {
    output = at::zeros({rois_num, channels, pooled_height, pooled_width},
                       input.options());
    return;
  }
  if (argmax.numel() == 0) {
    argmax = at::zeros({rois_num, channels, pooled_height, pooled_width},
                       argmax.options());
    return;
  }
  // zero element check
  if (input.numel() == 0 || rois.numel() == 0 || output.numel() == 0 ||
      argmax.numel() == 0) {
    return;
  }
  // Make the tensors channels-last contiguous; the kernel writes into fresh
  // channels-last buffers that are copied back at the end.
  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
  auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
  at::Tensor output_ =
      at::empty({rois_num, channels, pooled_height, pooled_width},
                input.options(), memory_format);
  at::Tensor argmax_ =
      at::empty({rois_num, channels, pooled_height, pooled_width},
                argmax.options(), memory_format);
  // calculate task dimension (one bin per core, see policyFuncForward)
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  policyFuncForward(rois_num * pooled_height * pooled_width, &k_dim, &k_type);
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  // get device pointers of the tensors
  auto input_impl = torch_mlu::getMluTensorImpl(input_);
  auto input_ptr = input_impl->cnnlMalloc();
  auto rois_impl = torch_mlu::getMluTensorImpl(rois);
  auto rois_ptr = rois_impl->cnnlMalloc();
  auto output_impl = torch_mlu::getMluTensorImpl(output_);
  auto output_ptr = output_impl->cnnlMalloc();
  auto argmax_impl = torch_mlu::getMluTensorImpl(argmax_);
  auto argmax_ptr = argmax_impl->cnnlMalloc();
  // get compute dtype of input
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(input_.dtype());
  // launch kernel
  CNLOG(INFO) << "Launch Kernel MLUKernelRoiPoolForward<<<" << k_dim.x << ", "
              << k_dim.y << ", " << k_dim.z << ">>>";
  KernelRoiPoolForward(k_dim, k_type, queue, data_type, input_ptr, rois_ptr,
                       batch, channels, height, width, pooled_height,
                       pooled_width, rois_num, spatial_scale, output_ptr,
                       (int *)argmax_ptr);
  // Copy the channels-last results back into the caller's tensors.
  output.copy_(output_);
  argmax.copy_(argmax_);
}
// Launch-dimension policy for the backward kernel: a single UNION1 job that
// occupies every core on the device (cores-per-cluster x cluster-count).
static void policyFuncBackward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
  k_dim->z = 1;
  k_dim->y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  *k_type = CNRT_FUNC_TYPE_UNION1;
}
// Validates inputs, converts them to channels-last contiguous layout and
// enqueues the RoiPool backward kernel on the current MLU queue: gradients of
// the pooled output are scattered back into `grad_input` at the positions
// recorded in `argmax` by the forward pass.
//
// grad_output : 4-D gradient w.r.t. the pooled output, Float or Half.
// rois        : 2-D roi tensor, same dtype as grad_output.
// argmax      : 4-D Int/Long max-location indices from the forward pass.
// grad_input  : 4-D gradient w.r.t. the input feature map (written in place).
void ROIPoolBackwardMLUKernelLauncher(Tensor grad_output, Tensor rois,
                                      Tensor argmax, Tensor grad_input,
                                      int pooled_height, int pooled_width,
                                      float spatial_scale) {
  // Check dtype.
  TORCH_CHECK(
      argmax.scalar_type() == at::kLong || argmax.scalar_type() == at::kInt,
      "argmax type should be Int or Long, got ", argmax.scalar_type());
  // Fixed typo in the error message: "FLoat" -> "Float".
  TORCH_CHECK((grad_output.scalar_type() == at::kFloat ||
               grad_output.scalar_type() == at::kHalf),
              "grad_output type should be Float or Half, got ",
              grad_output.scalar_type());
  // Check dtype relationship.
  TORCH_CHECK((rois.scalar_type() == grad_output.scalar_type()),
              "rois should have the same type as grad_output");
  // Check shape.
  TORCH_CHECK(grad_output.dim() == 4, "grad_output should be 4d tensor, got ",
              grad_output.dim(), "D");
  TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(),
              "D");
  TORCH_CHECK(argmax.dim() == 4, "argmax should be 4d tensor, got ",
              argmax.dim(), "D");
  TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1,
              "spatial_scale should be within (0, 1], got ", spatial_scale);
  // Check relationship between tensors.
  // Roi count (dim 0) must match between grad_output and rois.
  TORCH_CHECK(grad_output.size(0) == rois.size(0),
              "grad_output.size(0) = ", grad_output.size(0),
              ", while rois.size(0) = ", rois.size(0),
              ". They should be the same.");
  // Channel count must match between grad_output and argmax.
  TORCH_CHECK(grad_output.size(1) == argmax.size(1),
              "grad_output.size(1) = ", grad_output.size(1),
              ", while argmax.size(1) = ", argmax.size(1),
              ". They should be the same.");
  // Pooled height and width must match between grad_output and argmax.
  TORCH_CHECK(grad_output.size(2) == argmax.size(2),
              "argmax.size(2) = ", argmax.size(2),
              ", while grad_output.size(2) = ", grad_output.size(2),
              ". They should be the same.");
  TORCH_CHECK(grad_output.size(3) == argmax.size(3),
              "argmax.size(3) = ", argmax.size(3),
              ", while grad_output.size(3) = ", grad_output.size(3),
              ". They should be the same.");
  // Check zero element.
  if (grad_output.numel() == 0 || rois.numel() == 0 || argmax.numel() == 0 ||
      grad_input.numel() == 0) {
    // return if zero-element
    return;
  }
  // Make the gradient tensors channels-last contiguous for the kernel.
  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_output.dim());
  auto grad_output_ =
      torch_mlu::cnnl::ops::cnnl_contiguous(grad_output, memory_format);
  auto argmax_ = torch_mlu::cnnl::ops::cnnl_contiguous(argmax, memory_format);
  int boxes_num = grad_output.size(0);
  int no = grad_input.size(0);
  int channels = grad_input.size(1);
  int height = grad_input.size(2);
  int width = grad_input.size(3);
  // Zero-initialized accumulation buffer for the scattered gradients.
  auto grad_input_ = at::empty({no, channels, height, width},
                               grad_input.options(), memory_format)
                         .zero_();
  // get tensor impl
  auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output_);
  auto rois_impl = torch_mlu::getMluTensorImpl(rois);
  auto argmax_impl = torch_mlu::getMluTensorImpl(argmax_);
  auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_);
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  // get mlu device pointers
  auto grad_output_ptr = grad_output_impl->cnnlMalloc();
  auto rois_ptr = rois_impl->cnnlMalloc();
  auto argmax_ptr = argmax_impl->cnnlMalloc();
  auto grad_input_ptr = grad_input_impl->cnnlMalloc();
  // calculate task dimension (all cores on the device, see policyFuncBackward)
  cnrtDataType_t k_dtype = torch_mlu::toCnrtDtype(grad_input.dtype());
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  policyFuncBackward(&k_dim, &k_type);
  CNLOG(INFO) << "Launch Kernel MLUKernelRoiPoolBackward<<<" << k_dim.x << ", "
              << k_dim.y << ", " << k_dim.z << ">>>";
  KernelRoiPoolBackward(k_dim, k_type, queue, k_dtype, grad_output_ptr,
                        rois_ptr, (int *)argmax_ptr, grad_input_ptr, boxes_num,
                        pooled_height, pooled_width, channels, no, height,
                        width, spatial_scale);
  // Copy the accumulated result back into the caller's grad_input tensor.
  grad_input.copy_(grad_input_);
}
// MLU adapter for the generic roi_pool_forward dispatcher: forwards all
// arguments unchanged to the kernel launcher.
void roi_pool_forward_mlu(Tensor input, Tensor rois, Tensor output,
                          Tensor argmax, int pooled_height, int pooled_width,
                          float spatial_scale) {
  ROIPoolForwardMLUKernelLauncher(input, rois, output, argmax, pooled_height,
                                  pooled_width, spatial_scale);
}
// MLU adapter for the generic roi_pool_backward dispatcher: forwards all
// arguments unchanged to the kernel launcher.
void roi_pool_backward_mlu(Tensor grad_output, Tensor rois, Tensor argmax,
                           Tensor grad_input, int pooled_height,
                           int pooled_width, float spatial_scale) {
  ROIPoolBackwardMLUKernelLauncher(grad_output, rois, argmax, grad_input,
                                   pooled_height, pooled_width, spatial_scale);
}
// Generic dispatch entry points (declared by the device registry elsewhere);
// the REGISTER_DEVICE_IMPL calls below bind their MLU implementations.
void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
                           Tensor argmax, int pooled_height, int pooled_width,
                           float spatial_scale);
void roi_pool_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax,
                            Tensor grad_input, int pooled_height,
                            int pooled_width, float spatial_scale);

// Register the MLU backend for the roi_pool forward/backward dispatch.
REGISTER_DEVICE_IMPL(roi_pool_forward_impl, MLU, roi_pool_forward_mlu);
REGISTER_DEVICE_IMPL(roi_pool_backward_impl, MLU, roi_pool_backward_mlu);
...@@ -2,8 +2,11 @@ ...@@ -2,8 +2,11 @@
import os import os
import numpy as np import numpy as np
import pytest
import torch import torch
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
_USING_PARROTS = True _USING_PARROTS = True
try: try:
from parrots.autograd import gradcheck from parrots.autograd import gradcheck
...@@ -54,9 +57,7 @@ class TestRoiPool: ...@@ -54,9 +57,7 @@ class TestRoiPool:
else: else:
gradcheck(froipool, (x, rois), eps=1e-2, atol=1e-2) gradcheck(froipool, (x, rois), eps=1e-2, atol=1e-2)
def _test_roipool_allclose(self, dtype=torch.float): def _test_roipool_allclose(self, device, dtype=torch.float):
if not torch.cuda.is_available():
return
from mmcv.ops import roi_pool from mmcv.ops import roi_pool
pool_h = 2 pool_h = 2
pool_w = 2 pool_w = 2
...@@ -69,15 +70,32 @@ class TestRoiPool: ...@@ -69,15 +70,32 @@ class TestRoiPool:
np_grad = np.array(output[1]) np_grad = np.array(output[1])
x = torch.tensor( x = torch.tensor(
np_input, dtype=dtype, device='cuda', requires_grad=True) np_input, dtype=dtype, device=device, requires_grad=True)
rois = torch.tensor(np_rois, dtype=dtype, device='cuda') rois = torch.tensor(np_rois, dtype=dtype, device=device)
output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale) output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale)
output.backward(torch.ones_like(output)) output.backward(torch.ones_like(output))
assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3) assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3)
assert np.allclose(x.grad.data.cpu().numpy(), np_grad, 1e-3) assert np.allclose(x.grad.data.cpu().numpy(), np_grad, 1e-3)
def test_roipool_allclose(self): @pytest.mark.parametrize('device', [
self._test_roipool_allclose(torch.double) pytest.param(
self._test_roipool_allclose(torch.float) 'cuda',
self._test_roipool_allclose(torch.half) marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
@pytest.mark.parametrize('dtype', [
torch.float,
pytest.param(
torch.double,
marks=pytest.mark.skipif(
IS_MLU_AVAILABLE,
reason='MLU does not support for 64-bit floating point')),
torch.half
])
def test_roipool_allclose(self, device, dtype):
self._test_roipool_allclose(device, dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment