Unverified commit 0970ae94, authored by Jerry Jiarui XU, committed by GitHub

Add cpu ops (#370)

* update action

* remove 1.3 cuda

* reorder

* add no torch

* fixed version

* make py3.8 on pt1.5

* make py3.8 on pt1.5

* remove torch 1.3

* disable py3.8

* pip

* merge master with cuda compile fix

* add cpu roi align

* fixed test

* fixed no torch

* add CUDA_ARGS

* use one line

* gencode=61

* separate jobs

* update lint

* use parametrize test

* format and rename

* unit test for all
parent ab77af82
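In effect, this patch lets mmcv.ops.roi_align run on CPU tensors (previously CUDA-only). A minimal usage sketch, using the input/output values from the unit test below (assumes mmcv built from this commit):

import torch
from mmcv.ops import roi_align

x = torch.tensor([[[[1., 2.], [3., 4.]]]], requires_grad=True)  # NCHW, on CPU
rois = torch.tensor([[0., 0., 0., 1., 1.]])  # (batch_idx, x1, y1, x2, y2)
# 2x2 output, spatial_scale=1.0, sampling_ratio=2, average pooling, aligned
out = roi_align(x, rois, (2, 2), 1.0, 2, 'avg', True)
print(out)  # tensor([[[[1.0000, 1.2500], [1.5000, 1.7500]]]], ...)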
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: build
on: [push, pull_request]
jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3.7
        uses: actions/setup-python@v1
        with:
          python-version: 3.7
      - name: Install linting dependencies
        run: |
          python -m pip install --upgrade pip
@@ -34,20 +28,99 @@ jobs:
          source: mmcv/ops/csrc
          extensions: h,c,cpp,hpp,cu,cuh
          style: google
  build_cpu:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.7]
        torch: [1.4.0]
        include:
          - torch: 1.4.0
            torchvision: 0.4.2
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v1
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install system dependencies
        run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg
      - name: Install Pillow
        run: pip install Pillow==6.2.2
        if: ${{matrix.torchvision == '0.4.2'}}
      - name: Install PyTorch
        run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
      - name: Install unittest dependencies
        run: pip install pytest coverage lmdb PyTurboJPEG
      - name: Build and install
        run: rm -rf .eggs && pip install -e .
      - name: Run unittests and generate coverage report
        run: |
          coverage run --branch --source=mmcv -m pytest tests/
          coverage xml
          coverage report -m
  build_cuda:
    runs-on: ubuntu-latest
    env:
      CUDA: 10.1.105-1
      CUDA_SHORT: 10.1
      UBUNTU_VERSION: ubuntu1804
      FORCE_CUDA: 1
      MMCV_CUDA_ARGS: -gencode=arch=compute_61,code=sm_61
    strategy:
      matrix:
        python-version: [3.6, 3.7]
        torch: [1.3.0, 1.5.0]
        include:
          - torch: 1.3.0
            torchvision: 0.4.2
          - torch: 1.5.0
            torchvision: 0.6.0
          - python-version: 3.8
            torch: 1.5.0
            torchvision: 0.6.0
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v1
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install CUDA
        run: |
          export INSTALLER=cuda-repo-${UBUNTU_VERSION}_${CUDA}_amd64.deb
          wget http://developer.download.nvidia.com/compute/cuda/repos/${UBUNTU_VERSION}/x86_64/${INSTALLER}
          sudo dpkg -i ${INSTALLER}
          wget https://developer.download.nvidia.com/compute/cuda/repos/${UBUNTU_VERSION}/x86_64/7fa2af80.pub
          sudo apt-key add 7fa2af80.pub
          sudo apt update -qq
          sudo apt install -y cuda-${CUDA_SHORT/./-} cuda-cufft-dev-${CUDA_SHORT/./-}
          sudo apt clean
          export CUDA_HOME=/usr/local/cuda-${CUDA_SHORT}
          export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${CUDA_HOME}/include:${LD_LIBRARY_PATH}
          export PATH=${CUDA_HOME}/bin:${PATH}
          sudo apt-get install -y ninja-build
      - name: Install Pillow
        run: pip install Pillow==6.2.2
        if: ${{matrix.torchvision == '0.4.2'}}
      - name: Install PyTorch
        run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}}
      - name: Install system dependencies
        run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg
      - name: Install unittest dependencies
        run: pip install pytest coverage lmdb PyTurboJPEG
      - name: Build and install
        run: rm -rf .eggs && pip install -e .
      - name: Run unittests and generate coverage report
        run: |
          coverage run --branch --source=mmcv -m pytest tests/
          coverage xml
          coverage report -m
      # Only upload coverage report for python3.7 && pytorch1.5
      - name: Upload coverage to Codecov
        if: ${{matrix.torch == '1.5.0' && matrix.python-version == '3.7'}}
        uses: codecov/codecov-action@master
        with:
          file: ./coverage.xml
@@ -55,3 +128,36 @@ jobs:
          env_vars: OS,PYTHON
          name: codecov-umbrella
          fail_ci_if_error: false
  build_no_torch:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.7]
        torch: [1.4.0]
        include:
          - torch: 1.4.0
            torchvision: 0.4.2
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v1
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install system dependencies
        run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg
      - name: Install unittest dependencies
        run: pip install pytest coverage lmdb PyTurboJPEG
      - name: Build and install
        run: rm -rf .eggs && pip install -e .
      - name: Install Pillow
        run: pip install Pillow==6.2.2
        if: ${{matrix.torchvision == '0.4.2'}}
      - name: Install PyTorch
        run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
      - name: Run unittests and generate coverage report
        run: |
          coverage run --branch --source=mmcv -m pytest tests/
          coverage xml
          coverage report -m
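A note on the new environment variables: FORCE_CUDA=1 is typically used to force building the CUDA extensions even though the CI runner has no GPU, and MMCV_CUDA_ARGS carries extra flags for nvcc. A rough sketch of the consuming side (hypothetical helper; the exact wiring in mmcv's setup.py may differ):

import os

def nvcc_extra_args():
    # Forward user-supplied nvcc flags, e.g.
    # MMCV_CUDA_ARGS='-gencode=arch=compute_61,code=sm_61' as set above.
    extra = os.getenv('MMCV_CUDA_ARGS', '')
    return extra.split() if extra else []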
@@ -35,6 +35,38 @@ void roi_align_backward_cuda(Tensor grad_output, Tensor rois, Tensor argmax_y,
}
#endif
void ROIAlignForwardCPULauncher(Tensor input, Tensor rois, Tensor output,
                                Tensor argmax_y, Tensor argmax_x,
                                int aligned_height, int aligned_width,
                                float spatial_scale, int sampling_ratio,
                                int pool_mode, bool aligned);

void ROIAlignBackwardCPULauncher(Tensor grad_output, Tensor rois,
                                 Tensor argmax_y, Tensor argmax_x,
                                 Tensor grad_input, int aligned_height,
                                 int aligned_width, float spatial_scale,
                                 int sampling_ratio, int pool_mode,
                                 bool aligned);

void roi_align_forward_cpu(Tensor input, Tensor rois, Tensor output,
                           Tensor argmax_y, Tensor argmax_x, int aligned_height,
                           int aligned_width, float spatial_scale,
                           int sampling_ratio, int pool_mode, bool aligned) {
  ROIAlignForwardCPULauncher(input, rois, output, argmax_y, argmax_x,
                             aligned_height, aligned_width, spatial_scale,
                             sampling_ratio, pool_mode, aligned);
}

void roi_align_backward_cpu(Tensor grad_output, Tensor rois, Tensor argmax_y,
                            Tensor argmax_x, Tensor grad_input,
                            int aligned_height, int aligned_width,
                            float spatial_scale, int sampling_ratio,
                            int pool_mode, bool aligned) {
  ROIAlignBackwardCPULauncher(grad_output, rois, argmax_y, argmax_x, grad_input,
                              aligned_height, aligned_width, spatial_scale,
                              sampling_ratio, pool_mode, aligned);
}

void roi_align_forward(Tensor input, Tensor rois, Tensor output,
                       Tensor argmax_y, Tensor argmax_x, int aligned_height,
                       int aligned_width, float spatial_scale,
@@ -54,7 +86,14 @@ void roi_align_forward(Tensor input, Tensor rois, Tensor output,
AT_ERROR("RoIAlign is not compiled with GPU support");
#endif
} else {
AT_ERROR("RoIAlign is not implemented on CPU");
CHECK_CPU_INPUT(input);
CHECK_CPU_INPUT(rois);
CHECK_CPU_INPUT(output);
CHECK_CPU_INPUT(argmax_y);
CHECK_CPU_INPUT(argmax_x);
roi_align_forward_cpu(input, rois, output, argmax_y, argmax_x,
aligned_height, aligned_width, spatial_scale,
sampling_ratio, pool_mode, aligned);
}
}
@@ -77,6 +116,14 @@ void roi_align_backward(Tensor grad_output, Tensor rois, Tensor argmax_y,
    AT_ERROR("RoIAlign is not compiled with GPU support");
#endif
  } else {
    CHECK_CPU_INPUT(grad_output);
    CHECK_CPU_INPUT(rois);
    CHECK_CPU_INPUT(argmax_y);
    CHECK_CPU_INPUT(argmax_x);
    CHECK_CPU_INPUT(grad_input);

    roi_align_backward_cpu(grad_output, rois, argmax_y, argmax_x, grad_input,
                           aligned_height, aligned_width, spatial_scale,
                           sampling_ratio, pool_mode, aligned);
  }
}
// Modified from
// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/ROIAlign
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>

#include "../pytorch_cpp_helper.hpp"

// implementation taken from Caffe2
template <typename T>
struct PreCalc {
  int pos1;
  int pos2;
  int pos3;
  int pos4;
  T w1;
  T w2;
  T w3;
  T w4;
};
template <typename T>
void pre_calc_for_bilinear_interpolate(
    const int height, const int width, const int pooled_height,
    const int pooled_width, const int iy_upper, const int ix_upper,
    T roi_start_h, T roi_start_w, T bin_size_h, T bin_size_w,
    int roi_bin_grid_h, int roi_bin_grid_w, std::vector<PreCalc<T>>& pre_calc) {
  int pre_calc_index = 0;
  for (int ph = 0; ph < pooled_height; ph++) {
    for (int pw = 0; pw < pooled_width; pw++) {
      for (int iy = 0; iy < iy_upper; iy++) {
        const T yy = roi_start_h + ph * bin_size_h +
                     static_cast<T>(iy + .5f) * bin_size_h /
                         static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
        for (int ix = 0; ix < ix_upper; ix++) {
          const T xx = roi_start_w + pw * bin_size_w +
                       static_cast<T>(ix + .5f) * bin_size_w /
                           static_cast<T>(roi_bin_grid_w);

          T x = xx;
          T y = yy;
          // deal with: inverse elements are out of feature map boundary
          if (y < -1.0 || y > height || x < -1.0 || x > width) {
            // empty
            PreCalc<T> pc;
            pc.pos1 = 0;
            pc.pos2 = 0;
            pc.pos3 = 0;
            pc.pos4 = 0;
            pc.w1 = 0;
            pc.w2 = 0;
            pc.w3 = 0;
            pc.w4 = 0;
            pre_calc[pre_calc_index] = pc;
            pre_calc_index += 1;
            continue;
          }

          if (y <= 0) {
            y = 0;
          }
          if (x <= 0) {
            x = 0;
          }

          int y_low = (int)y;
          int x_low = (int)x;
          int y_high;
          int x_high;

          if (y_low >= height - 1) {
            y_high = y_low = height - 1;
            y = (T)y_low;
          } else {
            y_high = y_low + 1;
          }

          if (x_low >= width - 1) {
            x_high = x_low = width - 1;
            x = (T)x_low;
          } else {
            x_high = x_low + 1;
          }

          T ly = y - y_low;
          T lx = x - x_low;
          T hy = 1. - ly, hx = 1. - lx;
          T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

          // save weights and indices
          PreCalc<T> pc;
          pc.pos1 = y_low * width + x_low;
          pc.pos2 = y_low * width + x_high;
          pc.pos3 = y_high * width + x_low;
          pc.pos4 = y_high * width + x_high;
          pc.w1 = w1;
          pc.w2 = w2;
          pc.w3 = w3;
          pc.w4 = w4;
          pre_calc[pre_calc_index] = pc;
          pre_calc_index += 1;
        }
      }
    }
  }
}
template <typename T>
void ROIAlignForward(const int nthreads, const T* input, const T* rois,
                     T* output, T* argmax_y, T* argmax_x,
                     const int pooled_height, const int pooled_width,
                     const T spatial_scale, const int sampling_ratio,
                     const int pool_mode,  // 0 - max pool, 1 - avg pool
                     const bool aligned, const int channels, const int height,
                     const int width) {
  int n_rois = nthreads / channels / pooled_width / pooled_height;
  // (n, c, ph, pw) is an element in the pooled output
  // can be parallelized using omp
  // #pragma omp parallel for num_threads(32)
  for (int n = 0; n < n_rois; n++) {
    int index_n = n * channels * pooled_width * pooled_height;

    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];

    // Do not use rounding; this implementation detail is critical
    T offset = aligned ? (T)0.5 : (T)0.0;
    T roi_start_w = offset_rois[1] * spatial_scale - offset;
    T roi_start_h = offset_rois[2] * spatial_scale - offset;
    T roi_end_w = offset_rois[3] * spatial_scale - offset;
    T roi_end_h = offset_rois[4] * spatial_scale - offset;

    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    if (aligned) {
      AT_ASSERTM(roi_width >= 0 && roi_height >= 0,
                 "ROIs in ROIAlign cannot have negative size!");
    } else {  // for backward-compatibility only
      roi_width = std::max(roi_width, (T)1.);
      roi_height = std::max(roi_height, (T)1.);
    }

    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sampling_ratio > 0)
                             ? sampling_ratio
                             : ceil(roi_height / pooled_height);  // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

    // When the grid is empty, output zeros == 0/1, instead of NaN.
    const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1);  // e.g. = 4

    // we want to precalculate indices and weights shared by all channels,
    // this is the key point of optimization
    std::vector<PreCalc<T>> pre_calc(roi_bin_grid_h * roi_bin_grid_w *
                                     pooled_width * pooled_height);
    pre_calc_for_bilinear_interpolate(
        height, width, pooled_height, pooled_width, roi_bin_grid_h,
        roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, bin_size_w,
        roi_bin_grid_h, roi_bin_grid_w, pre_calc);

    for (int c = 0; c < channels; c++) {
      int index_n_c = index_n + c * pooled_width * pooled_height;
      const T* offset_input =
          input + (roi_batch_ind * channels + c) * height * width;
      int pre_calc_index = 0;

      for (int ph = 0; ph < pooled_height; ph++) {
        for (int pw = 0; pw < pooled_width; pw++) {
          int index = index_n_c + ph * pooled_width + pw;

          T output_val = 0.;
          T maxval = -10000;
          T maxidx_y = -1.f, maxidx_x = -1.f;
          for (int iy = 0; iy < roi_bin_grid_h; iy++) {
            const T y = roi_start_h + ph * bin_size_h +
                        static_cast<T>(iy + .5f) * bin_size_h /
                            static_cast<T>(roi_bin_grid_h);
            for (int ix = 0; ix < roi_bin_grid_w; ix++) {
              const T x = roi_start_w + pw * bin_size_w +
                          static_cast<T>(ix + .5f) * bin_size_w /
                              static_cast<T>(roi_bin_grid_w);
              PreCalc<T> pc = pre_calc[pre_calc_index];
              T val = pc.w1 * offset_input[pc.pos1] +
                      pc.w2 * offset_input[pc.pos2] +
                      pc.w3 * offset_input[pc.pos3] +
                      pc.w4 * offset_input[pc.pos4];
              if (val > maxval) {
                maxval = val;
                maxidx_y = y;
                maxidx_x = x;
              }
              output_val += val;
              pre_calc_index += 1;
            }
          }
          if (pool_mode == 0) {
            // We do max pooling inside a bin
            output[index] = maxval;
            argmax_y[index] = maxidx_y;
            argmax_x[index] = maxidx_x;
          } else if (pool_mode == 1) {
            // We do average (integral) pooling inside a bin
            output[index] = output_val / count;
          }  // if
        }    // for pw
      }      // for ph
    }        // for c
  }          // for n
}
template <typename T>
void bilinear_interpolate_gradient(const int height, const int width, T y, T x,
                                   T& w1, T& w2, T& w3, T& w4, int& x_low,
                                   int& x_high, int& y_low, int& y_high,
                                   const int index /* index for debug only*/) {
  // deal with cases that inverse elements are out of feature map boundary
  if (y < -1.0 || y > height || x < -1.0 || x > width) {
    // empty
    w1 = w2 = w3 = w4 = 0.;
    x_low = x_high = y_low = y_high = -1;
    return;
  }

  if (y <= 0) y = 0;
  if (x <= 0) x = 0;

  y_low = (int)y;
  x_low = (int)x;

  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y = (T)y_low;
  } else {
    y_high = y_low + 1;
  }

  if (x_low >= width - 1) {
    x_high = x_low = width - 1;
    x = (T)x_low;
  } else {
    x_high = x_low + 1;
  }

  T ly = y - y_low;
  T lx = x - x_low;
  T hy = 1. - ly, hx = 1. - lx;

  // reference in forward
  // T v1 = input[y_low * width + x_low];
  // T v2 = input[y_low * width + x_high];
  // T v3 = input[y_high * width + x_low];
  // T v4 = input[y_high * width + x_high];
  // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);

  w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

  return;
}

template <class T>
inline void add(T* address, const T& val) {
  *address += val;
}
template <typename T>
void ROIAlignBackward(const int nthreads, const T* grad_output, const T* rois,
                      const T* argmax_y, const T* argmax_x, T* grad_input,
                      const int pooled_height, const int pooled_width,
                      const T spatial_scale, const int sampling_ratio,
                      const int pool_mode,  // 0 - max pool, 1 - avg pool
                      const bool aligned, const int channels, const int height,
                      const int width, const int n_stride, const int c_stride,
                      const int h_stride, const int w_stride) {
  for (int index = 0; index < nthreads; index++) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];

    // Do not use rounding; this implementation detail is critical
    T offset = aligned ? (T)0.5 : (T)0.0;
    T roi_start_w = offset_rois[1] * spatial_scale - offset;
    T roi_start_h = offset_rois[2] * spatial_scale - offset;
    T roi_end_w = offset_rois[3] * spatial_scale - offset;
    T roi_end_h = offset_rois[4] * spatial_scale - offset;

    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    if (aligned) {
      AT_ASSERTM(roi_width >= 0 && roi_height >= 0,
                 "ROIs in ROIAlign cannot have negative size!");
    } else {  // for backward-compatibility only
      roi_width = std::max(roi_width, (T)1.);
      roi_height = std::max(roi_height, (T)1.);
    }

    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    T* offset_grad_input =
        grad_input + ((roi_batch_ind * channels + c) * height * width);

    int output_offset = n * n_stride + c * c_stride;
    const T* offset_grad_output = grad_output + output_offset;
    const T grad_output_this_bin =
        offset_grad_output[ph * h_stride + pw * w_stride];

    if (pool_mode == 0) {
      // We do max pooling inside a bin
      T y = argmax_y[index], x = argmax_x[index];
      if (y != -1.f) {
        T w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;
        bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
                                      x_low, x_high, y_low, y_high, index);

        T g1 = grad_output_this_bin * w1;
        T g2 = grad_output_this_bin * w2;
        T g3 = grad_output_this_bin * w3;
        T g4 = grad_output_this_bin * w4;

        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          // atomic add is not needed for now since it is single threaded
          add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
          add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
          add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
          add(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
        }  // if
      }    // mode
    } else if (pool_mode == 1) {
      // We do average (integral) pooling inside a bin
      // We use roi_bin_grid to sample the grid and mimic integral
      int roi_bin_grid_h = (sampling_ratio > 0)
                               ? sampling_ratio
                               : ceil(roi_height / pooled_height);  // e.g., = 2
      int roi_bin_grid_w = (sampling_ratio > 0)
                               ? sampling_ratio
                               : ceil(roi_width / pooled_width);

      const T count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4
      for (int iy = 0; iy < roi_bin_grid_h; iy++) {
        const T y = roi_start_h + ph * bin_size_h +
                    static_cast<T>(iy + .5f) * bin_size_h /
                        static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
        for (int ix = 0; ix < roi_bin_grid_w; ix++) {
          const T x = roi_start_w + pw * bin_size_w +
                      static_cast<T>(ix + .5f) * bin_size_w /
                          static_cast<T>(roi_bin_grid_w);

          T w1, w2, w3, w4;
          int x_low, x_high, y_low, y_high;
          bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
                                        x_low, x_high, y_low, y_high, index);

          T g1 = grad_output_this_bin * w1 / count;
          T g2 = grad_output_this_bin * w2 / count;
          T g3 = grad_output_this_bin * w3 / count;
          T g4 = grad_output_this_bin * w4 / count;

          if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
            // atomic add is not needed for now since it is single threaded
            add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
            add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
            add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
            add(offset_grad_input + y_high * width + x_high,
                static_cast<T>(g4));
          }  // if
        }    // ix
      }      // iy
    }        // mode
  }          // for
}  // ROIAlignBackward
void ROIAlignForwardCPULauncher(Tensor input, Tensor rois, Tensor output,
                                Tensor argmax_y, Tensor argmax_x,
                                int aligned_height, int aligned_width,
                                float spatial_scale, int sampling_ratio,
                                int pool_mode, bool aligned) {
  int output_size = output.numel();
  int channels = input.size(1);
  int height = input.size(2);
  int width = input.size(3);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "ROIAlign_forward", [&] {
        ROIAlignForward<scalar_t>(
            output_size, input.data_ptr<scalar_t>(), rois.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(), argmax_y.data_ptr<scalar_t>(),
            argmax_x.data_ptr<scalar_t>(), aligned_height, aligned_width,
            static_cast<scalar_t>(spatial_scale), sampling_ratio, pool_mode,
            aligned, channels, height, width);
      });
}

void ROIAlignBackwardCPULauncher(Tensor grad_output, Tensor rois,
                                 Tensor argmax_y, Tensor argmax_x,
                                 Tensor grad_input, int aligned_height,
                                 int aligned_width, float spatial_scale,
                                 int sampling_ratio, int pool_mode,
                                 bool aligned) {
  int output_size = grad_output.numel();
  int channels = grad_input.size(1);
  int height = grad_input.size(2);
  int width = grad_input.size(3);

  // get stride values to ensure indexing into gradients is correct.
  int n_stride = grad_output.stride(0);
  int c_stride = grad_output.stride(1);
  int h_stride = grad_output.stride(2);
  int w_stride = grad_output.stride(3);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_output.scalar_type(), "ROIAlign_backward", [&] {
        ROIAlignBackward<scalar_t>(
            output_size, grad_output.data_ptr<scalar_t>(),
            rois.data_ptr<scalar_t>(), argmax_y.data_ptr<scalar_t>(),
            argmax_x.data_ptr<scalar_t>(), grad_input.data_ptr<scalar_t>(),
            aligned_height, aligned_width, static_cast<scalar_t>(spatial_scale),
            sampling_ratio, pool_mode, aligned, channels, height, width,
            n_stride, c_stride, h_stride, w_stride);
      });
}
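As a cross-check of the sampling logic above, here is a minimal NumPy sketch of the aligned, average-pooling forward path for a single-channel feature map and one ROI (helper name and simplified clamping are ours, not mmcv's; out-of-bounds ROIs are not handled). It reproduces the first expected output of the unit test that follows:

import numpy as np

def roi_align_avg_ref(feat, roi, pool_h, pool_w, scale=1.0, ratio=2):
    # Sample a ratio x ratio grid per output bin, bilinearly interpolate
    # each sample, and average the samples (pool_mode == 1, aligned=True).
    h, w = feat.shape
    _, x1, y1, x2, y2 = roi
    start_w, start_h = x1 * scale - 0.5, y1 * scale - 0.5
    roi_w, roi_h = x2 * scale - 0.5 - start_w, y2 * scale - 0.5 - start_h
    bin_h, bin_w = roi_h / pool_h, roi_w / pool_w
    out = np.zeros((pool_h, pool_w))
    for ph in range(pool_h):
        for pw in range(pool_w):
            acc = 0.
            for iy in range(ratio):
                y = start_h + ph * bin_h + (iy + .5) * bin_h / ratio
                for ix in range(ratio):
                    x = start_w + pw * bin_w + (ix + .5) * bin_w / ratio
                    # clamp samples to the feature map (in-bounds ROIs only)
                    yc, xc = min(max(y, 0.), h - 1.), min(max(x, 0.), w - 1.)
                    y0, x0 = int(yc), int(xc)
                    y1_, x1_ = min(y0 + 1, h - 1), min(x0 + 1, w - 1)
                    ly, lx = yc - y0, xc - x0
                    acc += (feat[y0, x0] * (1 - ly) * (1 - lx) +
                            feat[y0, x1_] * (1 - ly) * lx +
                            feat[y1_, x0] * ly * (1 - lx) +
                            feat[y1_, x1_] * ly * lx)
            out[ph, pw] = acc / (ratio * ratio)
    return out

print(roi_align_avg_ref(np.array([[1., 2.], [3., 4.]]),
                        (0., 0., 0., 1., 1.), 2, 2))
# -> [[1.   1.25]
#     [1.5  1.75]]  (matches the first expected output below)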
import os
import numpy as np
import pytest
import torch
_USING_PARROTS = True
@@ -10,86 +9,91 @@ except ImportError:
    from torch.autograd import gradcheck
    _USING_PARROTS = False

cur_dir = os.path.dirname(os.path.abspath(__file__))
# yapf:disable
inputs = [([[[[1., 2.], [3., 4.]]]],
           [[0., 0., 0., 1., 1.]]),
          ([[[[1., 2.], [3., 4.]],
             [[4., 3.], [2., 1.]]]],
           [[0., 0., 0., 1., 1.]]),
          ([[[[1., 2., 5., 6.], [3., 4., 7., 8.],
              [9., 10., 13., 14.], [11., 12., 15., 16.]]]],
           [[0., 0., 0., 3., 3.]])]
outputs = [([[[[1.0, 1.25], [1.5, 1.75]]]],
            [[[[3.0625, 0.4375], [0.4375, 0.0625]]]]),
           ([[[[1.0, 1.25], [1.5, 1.75]],
              [[4.0, 3.75], [3.5, 3.25]]]],
            [[[[3.0625, 0.4375], [0.4375, 0.0625]],
              [[3.0625, 0.4375], [0.4375, 0.0625]]]]),
           ([[[[1.9375, 4.75], [7.5625, 10.375]]]],
            [[[[0.47265625, 0.42968750, 0.42968750, 0.04296875],
               [0.42968750, 0.39062500, 0.39062500, 0.03906250],
               [0.42968750, 0.39062500, 0.39062500, 0.03906250],
               [0.04296875, 0.03906250, 0.03906250, 0.00390625]]]])]
# yapf:enable

pool_h = 2
pool_w = 2
spatial_scale = 1.0
sampling_ratio = 2
def _test_roialign_gradcheck(device, dtype):
    if not torch.cuda.is_available() and device == 'cuda':
        pytest.skip('test requires GPU')
    try:
        from mmcv.ops import RoIAlign
    except ModuleNotFoundError:
        pytest.skip('RoIAlign op is not successfully compiled')

    if dtype is torch.half:
        pytest.skip('grad check does not support fp16')
    for case in inputs:
        np_input = np.array(case[0])
        np_rois = np.array(case[1])

        x = torch.tensor(
            np_input, dtype=dtype, device=device, requires_grad=True)
        rois = torch.tensor(np_rois, dtype=dtype, device=device)

        froipool = RoIAlign((pool_h, pool_w), spatial_scale, sampling_ratio)

        if _USING_PARROTS:
            pass
            # gradcheck(froipool, (x, rois), no_grads=[rois])
        else:
            gradcheck(froipool, (x, rois), eps=1e-2, atol=1e-2)
def _test_roialign_allclose(device, dtype):
    if not torch.cuda.is_available() and device == 'cuda':
        pytest.skip('test requires GPU')
    try:
        from mmcv.ops import roi_align
    except ModuleNotFoundError:
        pytest.skip('test requires compilation')
    pool_h = 2
    pool_w = 2
    spatial_scale = 1.0
    sampling_ratio = 2
    for case, output in zip(inputs, outputs):
        np_input = np.array(case[0])
        np_rois = np.array(case[1])
        np_output = np.array(output[0])
        np_grad = np.array(output[1])

        x = torch.tensor(
            np_input, dtype=dtype, device=device, requires_grad=True)
        rois = torch.tensor(np_rois, dtype=dtype, device=device)
        output = roi_align(x, rois, (pool_h, pool_w), spatial_scale,
                           sampling_ratio, 'avg', True)
        output.backward(torch.ones_like(output))
        assert np.allclose(
            output.data.type(torch.float).cpu().numpy(), np_output, atol=1e-3)
        assert np.allclose(
            x.grad.data.type(torch.float).cpu().numpy(), np_grad, atol=1e-3)
@pytest.mark.parametrize('device', ['cuda', 'cpu'])
@pytest.mark.parametrize('dtype', [torch.float, torch.double, torch.half])
def test_roialign(device, dtype):
    # run gradcheck for double only
    if dtype is torch.double:
        _test_roialign_gradcheck(device=device, dtype=dtype)
    _test_roialign_allclose(device=device, dtype=dtype)
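The two parametrize decorators expand test_roialign into six (device, dtype) combinations, and the CUDA cases skip themselves on CPU-only machines. A quick local run mirroring the CI step (test file path assumed; adjust to its actual location in the repo):

import pytest

# equivalent to `pytest -q tests/test_roi_align.py` from the repo root
raise SystemExit(pytest.main(['-q', 'tests/test_roi_align.py']))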