Commit b634945d authored by limm's avatar limm
Browse files

support v0.6

parent 5b3792fc
// Copyright (c) Facebook, Inc. and its affiliates.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include "box_iou_rotated_utils.h"
namespace detectron2 {
// 2D block with 32 * 16 = 512 threads per block
const int BLOCK_DIM_X = 32;
const int BLOCK_DIM_Y = 16;
template <typename T>
__global__ void box_iou_rotated_cuda_kernel(
const int n_boxes1,
const int n_boxes2,
const T* dev_boxes1,
const T* dev_boxes2,
T* dev_ious) {
const int row_start = blockIdx.x * blockDim.x;
const int col_start = blockIdx.y * blockDim.y;
const int row_size = min(n_boxes1 - row_start, blockDim.x);
const int col_size = min(n_boxes2 - col_start, blockDim.y);
__shared__ float block_boxes1[BLOCK_DIM_X * 5];
__shared__ float block_boxes2[BLOCK_DIM_Y * 5];
// It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y
if (threadIdx.x < row_size && threadIdx.y == 0) {
block_boxes1[threadIdx.x * 5 + 0] =
dev_boxes1[(row_start + threadIdx.x) * 5 + 0];
block_boxes1[threadIdx.x * 5 + 1] =
dev_boxes1[(row_start + threadIdx.x) * 5 + 1];
block_boxes1[threadIdx.x * 5 + 2] =
dev_boxes1[(row_start + threadIdx.x) * 5 + 2];
block_boxes1[threadIdx.x * 5 + 3] =
dev_boxes1[(row_start + threadIdx.x) * 5 + 3];
block_boxes1[threadIdx.x * 5 + 4] =
dev_boxes1[(row_start + threadIdx.x) * 5 + 4];
}
if (threadIdx.x < col_size && threadIdx.y == 0) {
block_boxes2[threadIdx.x * 5 + 0] =
dev_boxes2[(col_start + threadIdx.x) * 5 + 0];
block_boxes2[threadIdx.x * 5 + 1] =
dev_boxes2[(col_start + threadIdx.x) * 5 + 1];
block_boxes2[threadIdx.x * 5 + 2] =
dev_boxes2[(col_start + threadIdx.x) * 5 + 2];
block_boxes2[threadIdx.x * 5 + 3] =
dev_boxes2[(col_start + threadIdx.x) * 5 + 3];
block_boxes2[threadIdx.x * 5 + 4] =
dev_boxes2[(col_start + threadIdx.x) * 5 + 4];
}
__syncthreads();
if (threadIdx.x < row_size && threadIdx.y < col_size) {
int offset = (row_start + threadIdx.x) * n_boxes2 + col_start + threadIdx.y;
dev_ious[offset] = single_box_iou_rotated<T>(
block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5);
}
}
at::Tensor box_iou_rotated_cuda(
// input must be contiguous
const at::Tensor& boxes1,
const at::Tensor& boxes2) {
using scalar_t = float;
AT_ASSERTM(
boxes1.scalar_type() == at::kFloat, "boxes1 must be a float tensor");
AT_ASSERTM(
boxes2.scalar_type() == at::kFloat, "boxes2 must be a float tensor");
AT_ASSERTM(boxes1.is_cuda(), "boxes1 must be a CUDA tensor");
AT_ASSERTM(boxes2.is_cuda(), "boxes2 must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(boxes1.device());
auto num_boxes1 = boxes1.size(0);
auto num_boxes2 = boxes2.size(0);
at::Tensor ious =
at::empty({num_boxes1 * num_boxes2}, boxes1.options().dtype(at::kFloat));
bool transpose = false;
if (num_boxes1 > 0 && num_boxes2 > 0) {
scalar_t *data1 = boxes1.data_ptr<scalar_t>(),
*data2 = boxes2.data_ptr<scalar_t>();
if (num_boxes2 > 65535 * BLOCK_DIM_Y) {
AT_ASSERTM(
num_boxes1 <= 65535 * BLOCK_DIM_Y,
"Too many boxes for box_iou_rotated_cuda!");
// x dim is allowed to be large, but y dim cannot,
// so we transpose the two to avoid "invalid configuration argument"
// error. We assume one of them is small. Otherwise the result is hard to
// fit in memory anyway.
std::swap(num_boxes1, num_boxes2);
std::swap(data1, data2);
transpose = true;
}
const int blocks_x =
at::cuda::ATenCeilDiv(static_cast<int>(num_boxes1), BLOCK_DIM_X);
const int blocks_y =
at::cuda::ATenCeilDiv(static_cast<int>(num_boxes2), BLOCK_DIM_Y);
dim3 blocks(blocks_x, blocks_y);
dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
box_iou_rotated_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
num_boxes1,
num_boxes2,
data1,
data2,
(scalar_t*)ious.data_ptr<scalar_t>());
AT_CUDA_CHECK(cudaGetLastError());
}
// reshape from 1d array to 2d array
auto shape = std::vector<int64_t>{num_boxes1, num_boxes2};
if (transpose) {
return ious.view(shape).t();
} else {
return ious.view(shape);
}
}
} // namespace detectron2
// Copyright (c) Facebook, Inc. and its affiliates.
#pragma once
#include <cassert>
#include <cmath>
#if defined(__CUDACC__) || __HCC__ == 1 || __HIP__ == 1
// Designates functions callable from the host (CPU) and the device (GPU)
#define HOST_DEVICE __host__ __device__
#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
#else
#include <algorithm>
#define HOST_DEVICE
#define HOST_DEVICE_INLINE HOST_DEVICE inline
#endif
namespace detectron2 {
namespace {
template <typename T>
struct RotatedBox {
T x_ctr, y_ctr, w, h, a;
};
template <typename T>
struct Point {
T x, y;
HOST_DEVICE_INLINE Point(const T& px = 0, const T& py = 0) : x(px), y(py) {}
HOST_DEVICE_INLINE Point operator+(const Point& p) const {
return Point(x + p.x, y + p.y);
}
HOST_DEVICE_INLINE Point& operator+=(const Point& p) {
x += p.x;
y += p.y;
return *this;
}
HOST_DEVICE_INLINE Point operator-(const Point& p) const {
return Point(x - p.x, y - p.y);
}
HOST_DEVICE_INLINE Point operator*(const T coeff) const {
return Point(x * coeff, y * coeff);
}
};
template <typename T>
HOST_DEVICE_INLINE T dot_2d(const Point<T>& A, const Point<T>& B) {
return A.x * B.x + A.y * B.y;
}
// R: result type. can be different from input type
template <typename T, typename R = T>
HOST_DEVICE_INLINE R cross_2d(const Point<T>& A, const Point<T>& B) {
return static_cast<R>(A.x) * static_cast<R>(B.y) -
static_cast<R>(B.x) * static_cast<R>(A.y);
}
template <typename T>
HOST_DEVICE_INLINE void get_rotated_vertices(
const RotatedBox<T>& box,
Point<T> (&pts)[4]) {
// M_PI / 180. == 0.01745329251
double theta = box.a * 0.01745329251;
T cosTheta2 = (T)cos(theta) * 0.5f;
T sinTheta2 = (T)sin(theta) * 0.5f;
// y: top --> down; x: left --> right
pts[0].x = box.x_ctr + sinTheta2 * box.h + cosTheta2 * box.w;
pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
pts[1].x = box.x_ctr - sinTheta2 * box.h + cosTheta2 * box.w;
pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
pts[2].x = 2 * box.x_ctr - pts[0].x;
pts[2].y = 2 * box.y_ctr - pts[0].y;
pts[3].x = 2 * box.x_ctr - pts[1].x;
pts[3].y = 2 * box.y_ctr - pts[1].y;
}
template <typename T>
HOST_DEVICE_INLINE int get_intersection_points(
const Point<T> (&pts1)[4],
const Point<T> (&pts2)[4],
Point<T> (&intersections)[24]) {
// Line vector
// A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
Point<T> vec1[4], vec2[4];
for (int i = 0; i < 4; i++) {
vec1[i] = pts1[(i + 1) % 4] - pts1[i];
vec2[i] = pts2[(i + 1) % 4] - pts2[i];
}
// When computing the intersection area, it doesn't hurt if we have
// more (duplicated/approximate) intersections/vertices than needed,
// while it can cause drastic difference if we miss an intersection/vertex.
// Therefore, we add an epsilon to relax the comparisons between
// the float point numbers that decide the intersection points.
double EPS = 1e-5;
// Line test - test all line combos for intersection
int num = 0; // number of intersections
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
// Solve for 2x2 Ax=b
T det = cross_2d<T>(vec2[j], vec1[i]);
// This takes care of parallel lines
if (fabs(det) <= 1e-14) {
continue;
}
auto vec12 = pts2[j] - pts1[i];
T t1 = cross_2d<T>(vec2[j], vec12) / det;
T t2 = cross_2d<T>(vec1[i], vec12) / det;
if (t1 > -EPS && t1 < 1.0f + EPS && t2 > -EPS && t2 < 1.0f + EPS) {
intersections[num++] = pts1[i] + vec1[i] * t1;
}
}
}
// Check for vertices of rect1 inside rect2
{
const auto& AB = vec2[0];
const auto& DA = vec2[3];
auto ABdotAB = dot_2d<T>(AB, AB);
auto ADdotAD = dot_2d<T>(DA, DA);
for (int i = 0; i < 4; i++) {
// assume ABCD is the rectangle, and P is the point to be judged
// P is inside ABCD iff. P's projection on AB lies within AB
// and P's projection on AD lies within AD
auto AP = pts1[i] - pts2[0];
auto APdotAB = dot_2d<T>(AP, AB);
auto APdotAD = -dot_2d<T>(AP, DA);
if ((APdotAB > -EPS) && (APdotAD > -EPS) && (APdotAB < ABdotAB + EPS) &&
(APdotAD < ADdotAD + EPS)) {
intersections[num++] = pts1[i];
}
}
}
// Reverse the check - check for vertices of rect2 inside rect1
{
const auto& AB = vec1[0];
const auto& DA = vec1[3];
auto ABdotAB = dot_2d<T>(AB, AB);
auto ADdotAD = dot_2d<T>(DA, DA);
for (int i = 0; i < 4; i++) {
auto AP = pts2[i] - pts1[0];
auto APdotAB = dot_2d<T>(AP, AB);
auto APdotAD = -dot_2d<T>(AP, DA);
if ((APdotAB > -EPS) && (APdotAD > -EPS) && (APdotAB < ABdotAB + EPS) &&
(APdotAD < ADdotAD + EPS)) {
intersections[num++] = pts2[i];
}
}
}
return num;
}
template <typename T>
HOST_DEVICE_INLINE int convex_hull_graham(
const Point<T> (&p)[24],
const int& num_in,
Point<T> (&q)[24],
bool shift_to_zero = false) {
assert(num_in >= 2);
// Step 1:
// Find point with minimum y
// if more than 1 points have the same minimum y,
// pick the one with the minimum x.
int t = 0;
for (int i = 1; i < num_in; i++) {
if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
t = i;
}
}
auto& start = p[t]; // starting point
// Step 2:
// Subtract starting point from every points (for sorting in the next step)
for (int i = 0; i < num_in; i++) {
q[i] = p[i] - start;
}
// Swap the starting point to position 0
auto tmp = q[0];
q[0] = q[t];
q[t] = tmp;
// Step 3:
// Sort point 1 ~ num_in according to their relative cross-product values
// (essentially sorting according to angles)
// If the angles are the same, sort according to their distance to origin
T dist[24];
#if defined(__CUDACC__) || __HCC__ == 1 || __HIP__ == 1
// compute distance to origin before sort, and sort them together with the
// points
for (int i = 0; i < num_in; i++) {
dist[i] = dot_2d<T>(q[i], q[i]);
}
// CUDA version
// In the future, we can potentially use thrust
// for sorting here to improve speed (though not guaranteed)
for (int i = 1; i < num_in - 1; i++) {
for (int j = i + 1; j < num_in; j++) {
T crossProduct = cross_2d<T>(q[i], q[j]);
if ((crossProduct < -1e-6) ||
(fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
auto q_tmp = q[i];
q[i] = q[j];
q[j] = q_tmp;
auto dist_tmp = dist[i];
dist[i] = dist[j];
dist[j] = dist_tmp;
}
}
}
#else
// CPU version
std::sort(
q + 1, q + num_in, [](const Point<T>& A, const Point<T>& B) -> bool {
T temp = cross_2d<T>(A, B);
if (fabs(temp) < 1e-6) {
return dot_2d<T>(A, A) < dot_2d<T>(B, B);
} else {
return temp > 0;
}
});
// compute distance to origin after sort, since the points are now different.
for (int i = 0; i < num_in; i++) {
dist[i] = dot_2d<T>(q[i], q[i]);
}
#endif
// Step 4:
// Make sure there are at least 2 points (that don't overlap with each other)
// in the stack
int k; // index of the non-overlapped second point
for (k = 1; k < num_in; k++) {
if (dist[k] > 1e-8) {
break;
}
}
if (k == num_in) {
// We reach the end, which means the convex hull is just one point
q[0] = p[t];
return 1;
}
q[1] = q[k];
int m = 2; // 2 points in the stack
// Step 5:
// Finally we can start the scanning process.
// When a non-convex relationship between the 3 points is found
// (either concave shape or duplicated points),
// we pop the previous point from the stack
// until the 3-point relationship is convex again, or
// until the stack only contains two points
for (int i = k + 1; i < num_in; i++) {
while (m > 1) {
auto q1 = q[i] - q[m - 2], q2 = q[m - 1] - q[m - 2];
// cross_2d() uses FMA and therefore computes round(round(q1.x*q2.y) -
// q2.x*q1.y) So it may not return 0 even when q1==q2. Therefore we
// compare round(q1.x*q2.y) and round(q2.x*q1.y) directly. (round means
// round to nearest floating point).
if (q1.x * q2.y >= q2.x * q1.y)
m--;
else
break;
}
// Using double also helps, but float can solve the issue for now.
// while (m > 1 && cross_2d<T, double>(q[i] - q[m - 2], q[m - 1] - q[m - 2])
// >= 0) {
// m--;
// }
q[m++] = q[i];
}
// Step 6 (Optional):
// In general sense we need the original coordinates, so we
// need to shift the points back (reverting Step 2)
// But if we're only interested in getting the area/perimeter of the shape
// We can simply return.
if (!shift_to_zero) {
for (int i = 0; i < m; i++) {
q[i] += start;
}
}
return m;
}
template <typename T>
HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) {
if (m <= 2) {
return 0;
}
T area = 0;
for (int i = 1; i < m - 1; i++) {
area += fabs(cross_2d<T>(q[i] - q[0], q[i + 1] - q[0]));
}
return area / 2.0;
}
template <typename T>
HOST_DEVICE_INLINE T rotated_boxes_intersection(
const RotatedBox<T>& box1,
const RotatedBox<T>& box2) {
// There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
// from rotated_rect_intersection_pts
Point<T> intersectPts[24], orderedPts[24];
Point<T> pts1[4];
Point<T> pts2[4];
get_rotated_vertices<T>(box1, pts1);
get_rotated_vertices<T>(box2, pts2);
int num = get_intersection_points<T>(pts1, pts2, intersectPts);
if (num <= 2) {
return 0.0;
}
// Convex Hull to order the intersection points in clockwise order and find
// the contour area.
int num_convex = convex_hull_graham<T>(intersectPts, num, orderedPts, true);
return polygon_area<T>(orderedPts, num_convex);
}
} // namespace
template <typename T>
HOST_DEVICE_INLINE T
single_box_iou_rotated(T const* const box1_raw, T const* const box2_raw) {
// shift center to the middle point to achieve higher precision in result
RotatedBox<T> box1, box2;
auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0;
box1.x_ctr = box1_raw[0] - center_shift_x;
box1.y_ctr = box1_raw[1] - center_shift_y;
box1.w = box1_raw[2];
box1.h = box1_raw[3];
box1.a = box1_raw[4];
box2.x_ctr = box2_raw[0] - center_shift_x;
box2.y_ctr = box2_raw[1] - center_shift_y;
box2.w = box2_raw[2];
box2.h = box2_raw[3];
box2.a = box2_raw[4];
T area1 = box1.w * box1.h;
T area2 = box2.w * box2.h;
if (area1 < 1e-14 || area2 < 1e-14) {
return 0.f;
}
T intersection = rotated_boxes_intersection<T>(box1, box2);
T iou = intersection / (area1 + area2 - intersection);
return iou;
}
} // namespace detectron2
// Copyright (c) Facebook, Inc. and its affiliates.
#include "cocoeval.h"
#include <time.h>
#include <algorithm>
#include <cstdint>
#include <numeric>
using namespace pybind11::literals;
namespace detectron2 {
namespace COCOeval {
// Sort detections from highest score to lowest, such that
// detection_instances[detection_sorted_indices[t]] >=
// detection_instances[detection_sorted_indices[t+1]]. Use stable_sort to match
// original COCO API
void SortInstancesByDetectionScore(
const std::vector<InstanceAnnotation>& detection_instances,
std::vector<uint64_t>* detection_sorted_indices) {
detection_sorted_indices->resize(detection_instances.size());
std::iota(
detection_sorted_indices->begin(), detection_sorted_indices->end(), 0);
std::stable_sort(
detection_sorted_indices->begin(),
detection_sorted_indices->end(),
[&detection_instances](size_t j1, size_t j2) {
return detection_instances[j1].score > detection_instances[j2].score;
});
}
// Partition the ground truth objects based on whether or not to ignore them
// based on area
void SortInstancesByIgnore(
const std::array<double, 2>& area_range,
const std::vector<InstanceAnnotation>& ground_truth_instances,
std::vector<uint64_t>* ground_truth_sorted_indices,
std::vector<bool>* ignores) {
ignores->clear();
ignores->reserve(ground_truth_instances.size());
for (auto o : ground_truth_instances) {
ignores->push_back(
o.ignore || o.area < area_range[0] || o.area > area_range[1]);
}
ground_truth_sorted_indices->resize(ground_truth_instances.size());
std::iota(
ground_truth_sorted_indices->begin(),
ground_truth_sorted_indices->end(),
0);
std::stable_sort(
ground_truth_sorted_indices->begin(),
ground_truth_sorted_indices->end(),
[&ignores](size_t j1, size_t j2) {
return (int)(*ignores)[j1] < (int)(*ignores)[j2];
});
}
// For each IOU threshold, greedily match each detected instance to a ground
// truth instance (if possible) and store the results
void MatchDetectionsToGroundTruth(
const std::vector<InstanceAnnotation>& detection_instances,
const std::vector<uint64_t>& detection_sorted_indices,
const std::vector<InstanceAnnotation>& ground_truth_instances,
const std::vector<uint64_t>& ground_truth_sorted_indices,
const std::vector<bool>& ignores,
const std::vector<std::vector<double>>& ious,
const std::vector<double>& iou_thresholds,
const std::array<double, 2>& area_range,
ImageEvaluation* results) {
// Initialize memory to store return data matches and ignore
const int num_iou_thresholds = iou_thresholds.size();
const int num_ground_truth = ground_truth_sorted_indices.size();
const int num_detections = detection_sorted_indices.size();
std::vector<uint64_t> ground_truth_matches(
num_iou_thresholds * num_ground_truth, 0);
std::vector<uint64_t>& detection_matches = results->detection_matches;
std::vector<bool>& detection_ignores = results->detection_ignores;
std::vector<bool>& ground_truth_ignores = results->ground_truth_ignores;
detection_matches.resize(num_iou_thresholds * num_detections, 0);
detection_ignores.resize(num_iou_thresholds * num_detections, false);
ground_truth_ignores.resize(num_ground_truth);
for (auto g = 0; g < num_ground_truth; ++g) {
ground_truth_ignores[g] = ignores[ground_truth_sorted_indices[g]];
}
for (auto t = 0; t < num_iou_thresholds; ++t) {
for (auto d = 0; d < num_detections; ++d) {
// information about best match so far (match=-1 -> unmatched)
double best_iou = std::min(iou_thresholds[t], 1 - 1e-10);
int match = -1;
for (auto g = 0; g < num_ground_truth; ++g) {
// if this ground truth instance is already matched and not a
// crowd, it cannot be matched to another detection
if (ground_truth_matches[t * num_ground_truth + g] > 0 &&
!ground_truth_instances[ground_truth_sorted_indices[g]].is_crowd) {
continue;
}
// if detected instance matched to a regular ground truth
// instance, we can break on the first ground truth instance
// tagged as ignore (because they are sorted by the ignore tag)
if (match >= 0 && !ground_truth_ignores[match] &&
ground_truth_ignores[g]) {
break;
}
// if IOU overlap is the best so far, store the match appropriately
if (ious[d][ground_truth_sorted_indices[g]] >= best_iou) {
best_iou = ious[d][ground_truth_sorted_indices[g]];
match = g;
}
}
// if match was made, store id of match for both detection and
// ground truth
if (match >= 0) {
detection_ignores[t * num_detections + d] = ground_truth_ignores[match];
detection_matches[t * num_detections + d] =
ground_truth_instances[ground_truth_sorted_indices[match]].id;
ground_truth_matches[t * num_ground_truth + match] =
detection_instances[detection_sorted_indices[d]].id;
}
// set unmatched detections outside of area range to ignore
const InstanceAnnotation& detection =
detection_instances[detection_sorted_indices[d]];
detection_ignores[t * num_detections + d] =
detection_ignores[t * num_detections + d] ||
(detection_matches[t * num_detections + d] == 0 &&
(detection.area < area_range[0] || detection.area > area_range[1]));
}
}
// store detection score results
results->detection_scores.resize(detection_sorted_indices.size());
for (size_t d = 0; d < detection_sorted_indices.size(); ++d) {
results->detection_scores[d] =
detection_instances[detection_sorted_indices[d]].score;
}
}
std::vector<ImageEvaluation> EvaluateImages(
const std::vector<std::array<double, 2>>& area_ranges,
int max_detections,
const std::vector<double>& iou_thresholds,
const ImageCategoryInstances<std::vector<double>>& image_category_ious,
const ImageCategoryInstances<InstanceAnnotation>&
image_category_ground_truth_instances,
const ImageCategoryInstances<InstanceAnnotation>&
image_category_detection_instances) {
const int num_area_ranges = area_ranges.size();
const int num_images = image_category_ground_truth_instances.size();
const int num_categories =
image_category_ious.size() > 0 ? image_category_ious[0].size() : 0;
std::vector<uint64_t> detection_sorted_indices;
std::vector<uint64_t> ground_truth_sorted_indices;
std::vector<bool> ignores;
std::vector<ImageEvaluation> results_all(
num_images * num_area_ranges * num_categories);
// Store results for each image, category, and area range combination. Results
// for each IOU threshold are packed into the same ImageEvaluation object
for (auto i = 0; i < num_images; ++i) {
for (auto c = 0; c < num_categories; ++c) {
const std::vector<InstanceAnnotation>& ground_truth_instances =
image_category_ground_truth_instances[i][c];
const std::vector<InstanceAnnotation>& detection_instances =
image_category_detection_instances[i][c];
SortInstancesByDetectionScore(
detection_instances, &detection_sorted_indices);
if ((int)detection_sorted_indices.size() > max_detections) {
detection_sorted_indices.resize(max_detections);
}
for (size_t a = 0; a < area_ranges.size(); ++a) {
SortInstancesByIgnore(
area_ranges[a],
ground_truth_instances,
&ground_truth_sorted_indices,
&ignores);
MatchDetectionsToGroundTruth(
detection_instances,
detection_sorted_indices,
ground_truth_instances,
ground_truth_sorted_indices,
ignores,
image_category_ious[i][c],
iou_thresholds,
area_ranges[a],
&results_all
[c * num_area_ranges * num_images + a * num_images + i]);
}
}
}
return results_all;
}
// Convert a python list to a vector
template <typename T>
std::vector<T> list_to_vec(const py::list& l) {
std::vector<T> v(py::len(l));
for (int i = 0; i < (int)py::len(l); ++i) {
v[i] = l[i].cast<T>();
}
return v;
}
// Helper function to Accumulate()
// Considers the evaluation results applicable to a particular category, area
// range, and max_detections parameter setting, which begin at
// evaluations[evaluation_index]. Extracts a sorted list of length n of all
// applicable detection instances concatenated across all images in the dataset,
// which are represented by the outputs evaluation_indices, detection_scores,
// image_detection_indices, and detection_sorted_indices--all of which are
// length n. evaluation_indices[i] stores the applicable index into
// evaluations[] for instance i, which has detection score detection_score[i],
// and is the image_detection_indices[i]'th of the list of detections
// for the image containing i. detection_sorted_indices[] defines a sorted
// permutation of the 3 other outputs
int BuildSortedDetectionList(
const std::vector<ImageEvaluation>& evaluations,
const int64_t evaluation_index,
const int64_t num_images,
const int max_detections,
std::vector<uint64_t>* evaluation_indices,
std::vector<double>* detection_scores,
std::vector<uint64_t>* detection_sorted_indices,
std::vector<uint64_t>* image_detection_indices) {
assert(evaluations.size() >= evaluation_index + num_images);
// Extract a list of object instances of the applicable category, area
// range, and max detections requirements such that they can be sorted
image_detection_indices->clear();
evaluation_indices->clear();
detection_scores->clear();
image_detection_indices->reserve(num_images * max_detections);
evaluation_indices->reserve(num_images * max_detections);
detection_scores->reserve(num_images * max_detections);
int num_valid_ground_truth = 0;
for (auto i = 0; i < num_images; ++i) {
const ImageEvaluation& evaluation = evaluations[evaluation_index + i];
for (int d = 0;
d < (int)evaluation.detection_scores.size() && d < max_detections;
++d) { // detected instances
evaluation_indices->push_back(evaluation_index + i);
image_detection_indices->push_back(d);
detection_scores->push_back(evaluation.detection_scores[d]);
}
for (auto ground_truth_ignore : evaluation.ground_truth_ignores) {
if (!ground_truth_ignore) {
++num_valid_ground_truth;
}
}
}
// Sort detections by decreasing score, using stable sort to match
// python implementation
detection_sorted_indices->resize(detection_scores->size());
std::iota(
detection_sorted_indices->begin(), detection_sorted_indices->end(), 0);
std::stable_sort(
detection_sorted_indices->begin(),
detection_sorted_indices->end(),
[&detection_scores](size_t j1, size_t j2) {
return (*detection_scores)[j1] > (*detection_scores)[j2];
});
return num_valid_ground_truth;
}
// Helper function to Accumulate()
// Compute a precision recall curve given a sorted list of detected instances
// encoded in evaluations, evaluation_indices, detection_scores,
// detection_sorted_indices, image_detection_indices (see
// BuildSortedDetectionList()). Using vectors precisions and recalls
// and temporary storage, output the results into precisions_out, recalls_out,
// and scores_out, which are large buffers containing many precion/recall curves
// for all possible parameter settings, with precisions_out_index and
// recalls_out_index defining the applicable indices to store results.
void ComputePrecisionRecallCurve(
const int64_t precisions_out_index,
const int64_t precisions_out_stride,
const int64_t recalls_out_index,
const std::vector<double>& recall_thresholds,
const int iou_threshold_index,
const int num_iou_thresholds,
const int num_valid_ground_truth,
const std::vector<ImageEvaluation>& evaluations,
const std::vector<uint64_t>& evaluation_indices,
const std::vector<double>& detection_scores,
const std::vector<uint64_t>& detection_sorted_indices,
const std::vector<uint64_t>& image_detection_indices,
std::vector<double>* precisions,
std::vector<double>* recalls,
std::vector<double>* precisions_out,
std::vector<double>* scores_out,
std::vector<double>* recalls_out) {
assert(recalls_out->size() > recalls_out_index);
// Compute precision/recall for each instance in the sorted list of detections
int64_t true_positives_sum = 0, false_positives_sum = 0;
precisions->clear();
recalls->clear();
precisions->reserve(detection_sorted_indices.size());
recalls->reserve(detection_sorted_indices.size());
assert(!evaluations.empty() || detection_sorted_indices.empty());
for (auto detection_sorted_index : detection_sorted_indices) {
const ImageEvaluation& evaluation =
evaluations[evaluation_indices[detection_sorted_index]];
const auto num_detections =
evaluation.detection_matches.size() / num_iou_thresholds;
const auto detection_index = iou_threshold_index * num_detections +
image_detection_indices[detection_sorted_index];
assert(evaluation.detection_matches.size() > detection_index);
assert(evaluation.detection_ignores.size() > detection_index);
const int64_t detection_match =
evaluation.detection_matches[detection_index];
const bool detection_ignores =
evaluation.detection_ignores[detection_index];
const auto true_positive = detection_match > 0 && !detection_ignores;
const auto false_positive = detection_match == 0 && !detection_ignores;
if (true_positive) {
++true_positives_sum;
}
if (false_positive) {
++false_positives_sum;
}
const double recall =
static_cast<double>(true_positives_sum) / num_valid_ground_truth;
recalls->push_back(recall);
const int64_t num_valid_detections =
true_positives_sum + false_positives_sum;
const double precision = num_valid_detections > 0
? static_cast<double>(true_positives_sum) / num_valid_detections
: 0.0;
precisions->push_back(precision);
}
(*recalls_out)[recalls_out_index] = !recalls->empty() ? recalls->back() : 0;
for (int64_t i = static_cast<int64_t>(precisions->size()) - 1; i > 0; --i) {
if ((*precisions)[i] > (*precisions)[i - 1]) {
(*precisions)[i - 1] = (*precisions)[i];
}
}
// Sample the per instance precision/recall list at each recall threshold
for (size_t r = 0; r < recall_thresholds.size(); ++r) {
// first index in recalls >= recall_thresholds[r]
std::vector<double>::iterator low = std::lower_bound(
recalls->begin(), recalls->end(), recall_thresholds[r]);
size_t precisions_index = low - recalls->begin();
const auto results_ind = precisions_out_index + r * precisions_out_stride;
assert(results_ind < precisions_out->size());
assert(results_ind < scores_out->size());
if (precisions_index < precisions->size()) {
(*precisions_out)[results_ind] = (*precisions)[precisions_index];
(*scores_out)[results_ind] =
detection_scores[detection_sorted_indices[precisions_index]];
} else {
(*precisions_out)[results_ind] = 0;
(*scores_out)[results_ind] = 0;
}
}
}
py::dict Accumulate(
const py::object& params,
const std::vector<ImageEvaluation>& evaluations) {
const std::vector<double> recall_thresholds =
list_to_vec<double>(params.attr("recThrs"));
const std::vector<int> max_detections =
list_to_vec<int>(params.attr("maxDets"));
const int num_iou_thresholds = py::len(params.attr("iouThrs"));
const int num_recall_thresholds = py::len(params.attr("recThrs"));
const int num_categories = params.attr("useCats").cast<int>() == 1
? py::len(params.attr("catIds"))
: 1;
const int num_area_ranges = py::len(params.attr("areaRng"));
const int num_max_detections = py::len(params.attr("maxDets"));
const int num_images = py::len(params.attr("imgIds"));
std::vector<double> precisions_out(
num_iou_thresholds * num_recall_thresholds * num_categories *
num_area_ranges * num_max_detections,
-1);
std::vector<double> recalls_out(
num_iou_thresholds * num_categories * num_area_ranges *
num_max_detections,
-1);
std::vector<double> scores_out(
num_iou_thresholds * num_recall_thresholds * num_categories *
num_area_ranges * num_max_detections,
-1);
// Consider the list of all detected instances in the entire dataset in one
// large list. evaluation_indices, detection_scores,
// image_detection_indices, and detection_sorted_indices all have the same
// length as this list, such that each entry corresponds to one detected
// instance
std::vector<uint64_t> evaluation_indices; // indices into evaluations[]
std::vector<double> detection_scores; // detection scores of each instance
std::vector<uint64_t> detection_sorted_indices; // sorted indices of all
// instances in the dataset
std::vector<uint64_t>
image_detection_indices; // indices into the list of detected instances in
// the same image as each instance
std::vector<double> precisions, recalls;
for (auto c = 0; c < num_categories; ++c) {
for (auto a = 0; a < num_area_ranges; ++a) {
for (auto m = 0; m < num_max_detections; ++m) {
// The COCO PythonAPI assumes evaluations[] (the return value of
// COCOeval::EvaluateImages() is one long list storing results for each
// combination of category, area range, and image id, with categories in
// the outermost loop and images in the innermost loop.
const int64_t evaluations_index =
c * num_area_ranges * num_images + a * num_images;
int num_valid_ground_truth = BuildSortedDetectionList(
evaluations,
evaluations_index,
num_images,
max_detections[m],
&evaluation_indices,
&detection_scores,
&detection_sorted_indices,
&image_detection_indices);
if (num_valid_ground_truth == 0) {
continue;
}
for (auto t = 0; t < num_iou_thresholds; ++t) {
// recalls_out is a flattened vectors representing a
// num_iou_thresholds X num_categories X num_area_ranges X
// num_max_detections matrix
const int64_t recalls_out_index =
t * num_categories * num_area_ranges * num_max_detections +
c * num_area_ranges * num_max_detections +
a * num_max_detections + m;
// precisions_out and scores_out are flattened vectors
// representing a num_iou_thresholds X num_recall_thresholds X
// num_categories X num_area_ranges X num_max_detections matrix
const int64_t precisions_out_stride =
num_categories * num_area_ranges * num_max_detections;
const int64_t precisions_out_index = t * num_recall_thresholds *
num_categories * num_area_ranges * num_max_detections +
c * num_area_ranges * num_max_detections +
a * num_max_detections + m;
ComputePrecisionRecallCurve(
precisions_out_index,
precisions_out_stride,
recalls_out_index,
recall_thresholds,
t,
num_iou_thresholds,
num_valid_ground_truth,
evaluations,
evaluation_indices,
detection_scores,
detection_sorted_indices,
image_detection_indices,
&precisions,
&recalls,
&precisions_out,
&scores_out,
&recalls_out);
}
}
}
}
time_t rawtime;
struct tm local_time;
std::array<char, 200> buffer;
time(&rawtime);
#ifdef _WIN32
localtime_s(&local_time, &rawtime);
#else
localtime_r(&rawtime, &local_time);
#endif
strftime(
buffer.data(), 200, "%Y-%m-%d %H:%num_max_detections:%S", &local_time);
return py::dict(
"params"_a = params,
"counts"_a = std::vector<int64_t>(
{num_iou_thresholds,
num_recall_thresholds,
num_categories,
num_area_ranges,
num_max_detections}),
"date"_a = buffer,
"precision"_a = precisions_out,
"recall"_a = recalls_out,
"scores"_a = scores_out);
}
} // namespace COCOeval
} // namespace detectron2
// Copyright (c) Facebook, Inc. and its affiliates.
#pragma once
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <vector>
namespace py = pybind11;
namespace detectron2 {
namespace COCOeval {
// Annotation data for a single object instance in an image
struct InstanceAnnotation {
InstanceAnnotation(
uint64_t id,
double score,
double area,
bool is_crowd,
bool ignore)
: id{id}, score{score}, area{area}, is_crowd{is_crowd}, ignore{ignore} {}
uint64_t id;
double score = 0.;
double area = 0.;
bool is_crowd = false;
bool ignore = false;
};
// Stores intermediate results for evaluating detection results for a single
// image that has D detected instances and G ground truth instances. This stores
// matches between detected and ground truth instances
struct ImageEvaluation {
// For each of the D detected instances, the id of the matched ground truth
// instance, or 0 if unmatched
std::vector<uint64_t> detection_matches;
// The detection score of each of the D detected instances
std::vector<double> detection_scores;
// Marks whether or not each of G instances was ignored from evaluation (e.g.,
// because it's outside area_range)
std::vector<bool> ground_truth_ignores;
// Marks whether or not each of D instances was ignored from evaluation (e.g.,
// because it's outside aRng)
std::vector<bool> detection_ignores;
};
template <class T>
using ImageCategoryInstances = std::vector<std::vector<std::vector<T>>>;
// C++ implementation of COCO API cocoeval.py::COCOeval.evaluateImg(). For each
// combination of image, category, area range settings, and IOU thresholds to
// evaluate, it matches detected instances to ground truth instances and stores
// the results into a vector of ImageEvaluation results, which will be
// interpreted by the COCOeval::Accumulate() function to produce precion-recall
// curves. The parameters of nested vectors have the following semantics:
// image_category_ious[i][c][d][g] is the intersection over union of the d'th
// detected instance and g'th ground truth instance of
// category category_ids[c] in image image_ids[i]
// image_category_ground_truth_instances[i][c] is a vector of ground truth
// instances in image image_ids[i] of category category_ids[c]
// image_category_detection_instances[i][c] is a vector of detected
// instances in image image_ids[i] of category category_ids[c]
std::vector<ImageEvaluation> EvaluateImages(
const std::vector<std::array<double, 2>>& area_ranges, // vector of 2-tuples
int max_detections,
const std::vector<double>& iou_thresholds,
const ImageCategoryInstances<std::vector<double>>& image_category_ious,
const ImageCategoryInstances<InstanceAnnotation>&
image_category_ground_truth_instances,
const ImageCategoryInstances<InstanceAnnotation>&
image_category_detection_instances);
// C++ implementation of COCOeval.accumulate(), which generates precision
// recall curves for each set of category, IOU threshold, detection area range,
// and max number of detections parameters. It is assumed that the parameter
// evaluations is the return value of the functon COCOeval::EvaluateImages(),
// which was called with the same parameter settings params
py::dict Accumulate(
const py::object& params,
const std::vector<ImageEvaluation>& evalutations);
} // namespace COCOeval
} // namespace detectron2
// Copyright (c) Facebook, Inc. and its affiliates.
#include <cuda_runtime_api.h>
namespace detectron2 {
int get_cudart_version() {
// Not a ROCM platform: Either HIP is not used, or
// it is used, but platform is not ROCM (i.e. it is CUDA)
#if !defined(__HIP_PLATFORM_HCC__)
return CUDART_VERSION;
#else
int version = 0;
#if HIP_VERSION_MAJOR != 0
// Create a convention similar to that of CUDA, as assumed by other
// parts of the code.
version = HIP_VERSION_MINOR;
version += (HIP_VERSION_MAJOR * 100);
#else
hipRuntimeGetVersion(&version);
#endif
return version;
#endif
}
} // namespace detectron2
// Copyright (c) Facebook, Inc. and its affiliates.
#pragma once
#include <torch/types.h>
namespace detectron2 {
#if defined(WITH_CUDA) || defined(WITH_HIP)
int deform_conv_forward_cuda(
at::Tensor input,
at::Tensor weight,
at::Tensor offset,
at::Tensor output,
at::Tensor columns,
at::Tensor ones,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH,
int group,
int deformable_group,
int im2col_step);
int deform_conv_backward_input_cuda(
at::Tensor input,
at::Tensor offset,
at::Tensor gradOutput,
at::Tensor gradInput,
at::Tensor gradOffset,
at::Tensor weight,
at::Tensor columns,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH,
int group,
int deformable_group,
int im2col_step);
int deform_conv_backward_parameters_cuda(
at::Tensor input,
at::Tensor offset,
at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns,
at::Tensor ones,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH,
int group,
int deformable_group,
float scale,
int im2col_step);
void modulated_deform_conv_cuda_forward(
at::Tensor input,
at::Tensor weight,
at::Tensor bias,
at::Tensor ones,
at::Tensor offset,
at::Tensor mask,
at::Tensor output,
at::Tensor columns,
int kernel_h,
int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const int group,
const int deformable_group,
const bool with_bias);
void modulated_deform_conv_cuda_backward(
at::Tensor input,
at::Tensor weight,
at::Tensor bias,
at::Tensor ones,
at::Tensor offset,
at::Tensor mask,
at::Tensor columns,
at::Tensor grad_input,
at::Tensor grad_weight,
at::Tensor grad_bias,
at::Tensor grad_offset,
at::Tensor grad_mask,
at::Tensor grad_output,
int kernel_h,
int kernel_w,
int stride_h,
int stride_w,
int pad_h,
int pad_w,
int dilation_h,
int dilation_w,
int group,
int deformable_group,
const bool with_bias);
#endif
inline int deform_conv_forward(
at::Tensor input,
at::Tensor weight,
at::Tensor offset,
at::Tensor output,
at::Tensor columns,
at::Tensor ones,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH,
int group,
int deformable_group,
int im2col_step) {
if (input.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!");
TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!");
return deform_conv_forward_cuda(
input,
weight,
offset,
output,
columns,
ones,
kW,
kH,
dW,
dH,
padW,
padH,
dilationW,
dilationH,
group,
deformable_group,
im2col_step);
#else
AT_ERROR("Detectron2 is not compiled with GPU support!");
#endif
}
AT_ERROR("This operator is not implemented on CPU");
}
inline int deform_conv_backward_input(
at::Tensor input,
at::Tensor offset,
at::Tensor gradOutput,
at::Tensor gradInput,
at::Tensor gradOffset,
at::Tensor weight,
at::Tensor columns,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH,
int group,
int deformable_group,
int im2col_step) {
if (gradOutput.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_CHECK(input.is_cuda(), "input tensor is not on GPU!");
TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!");
TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!");
return deform_conv_backward_input_cuda(
input,
offset,
gradOutput,
gradInput,
gradOffset,
weight,
columns,
kW,
kH,
dW,
dH,
padW,
padH,
dilationW,
dilationH,
group,
deformable_group,
im2col_step);
#else
AT_ERROR("Detectron2 is not compiled with GPU support!");
#endif
}
AT_ERROR("This operator is not implemented on CPU");
}
inline int deform_conv_backward_filter(
at::Tensor input,
at::Tensor offset,
at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns,
at::Tensor ones,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH,
int group,
int deformable_group,
float scale,
int im2col_step) {
if (gradOutput.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_CHECK(input.is_cuda(), "input tensor is not on GPU!");
TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!");
return deform_conv_backward_parameters_cuda(
input,
offset,
gradOutput,
gradWeight,
columns,
ones,
kW,
kH,
dW,
dH,
padW,
padH,
dilationW,
dilationH,
group,
deformable_group,
scale,
im2col_step);
#else
AT_ERROR("Detectron2 is not compiled with GPU support!");
#endif
}
AT_ERROR("This operator is not implemented on CPU");
}
inline void modulated_deform_conv_forward(
at::Tensor input,
at::Tensor weight,
at::Tensor bias,
at::Tensor ones,
at::Tensor offset,
at::Tensor mask,
at::Tensor output,
at::Tensor columns,
int kernel_h,
int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const int group,
const int deformable_group,
const bool with_bias) {
if (input.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!");
TORCH_CHECK(bias.is_cuda(), "bias tensor is not on GPU!");
TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!");
return modulated_deform_conv_cuda_forward(
input,
weight,
bias,
ones,
offset,
mask,
output,
columns,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
group,
deformable_group,
with_bias);
#else
AT_ERROR("Detectron2 is not compiled with GPU support!");
#endif
}
AT_ERROR("This operator is not implemented on CPU");
}
inline void modulated_deform_conv_backward(
at::Tensor input,
at::Tensor weight,
at::Tensor bias,
at::Tensor ones,
at::Tensor offset,
at::Tensor mask,
at::Tensor columns,
at::Tensor grad_input,
at::Tensor grad_weight,
at::Tensor grad_bias,
at::Tensor grad_offset,
at::Tensor grad_mask,
at::Tensor grad_output,
int kernel_h,
int kernel_w,
int stride_h,
int stride_w,
int pad_h,
int pad_w,
int dilation_h,
int dilation_w,
int group,
int deformable_group,
const bool with_bias) {
if (grad_output.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_CHECK(input.is_cuda(), "input tensor is not on GPU!");
TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!");
TORCH_CHECK(bias.is_cuda(), "bias tensor is not on GPU!");
TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!");
return modulated_deform_conv_cuda_backward(
input,
weight,
bias,
ones,
offset,
mask,
columns,
grad_input,
grad_weight,
grad_bias,
grad_offset,
grad_mask,
grad_output,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
group,
deformable_group,
with_bias);
#else
AT_ERROR("Detectron2 is not compiled with GPU support!");
#endif
}
AT_ERROR("This operator is not implemented on CPU");
}
} // namespace detectron2
// Copyright (c) Facebook, Inc. and its affiliates.
// modified from
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp
// Original license: Apache 2.0
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
// Original license: Apache 2.0
#include <torch/types.h>
#include "deform_conv.h"
#include <cmath>
#include <vector>
namespace detectron2 {
void deformable_im2col(
const at::Tensor data_im,
const at::Tensor data_offset,
const int channels,
const int height,
const int width,
const int ksize_h,
const int ksize_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int parallel_imgs,
const int deformable_group,
at::Tensor data_col);
void deformable_col2im(
const at::Tensor data_col,
const at::Tensor data_offset,
const int channels,
const int height,
const int width,
const int ksize_h,
const int ksize_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int parallel_imgs,
const int deformable_group,
at::Tensor grad_im);
void deformable_col2im_coord(
const at::Tensor data_col,
const at::Tensor data_im,
const at::Tensor data_offset,
const int channels,
const int height,
const int width,
const int ksize_h,
const int ksize_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int parallel_imgs,
const int deformable_group,
at::Tensor grad_offset);
void modulated_deformable_im2col_cuda(
const at::Tensor data_im,
const at::Tensor data_offset,
const at::Tensor data_mask,
const int batch_size,
const int channels,
const int height_im,
const int width_im,
const int height_col,
const int width_col,
const int kernel_h,
const int kenerl_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int deformable_group,
at::Tensor data_col);
void modulated_deformable_col2im_cuda(
const at::Tensor data_col,
const at::Tensor data_offset,
const at::Tensor data_mask,
const int batch_size,
const int channels,
const int height_im,
const int width_im,
const int height_col,
const int width_col,
const int kernel_h,
const int kenerl_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int deformable_group,
at::Tensor grad_im);
void modulated_deformable_col2im_coord_cuda(
const at::Tensor data_col,
const at::Tensor data_im,
const at::Tensor data_offset,
const at::Tensor data_mask,
const int batch_size,
const int channels,
const int height_im,
const int width_im,
const int height_col,
const int width_col,
const int kernel_h,
const int kenerl_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int deformable_group,
at::Tensor grad_offset,
at::Tensor grad_mask);
void shape_check(
at::Tensor input,
at::Tensor offset,
at::Tensor* gradOutput,
at::Tensor weight,
int kH,
int kW,
int dH,
int dW,
int padH,
int padW,
int dilationH,
int dilationW,
int group,
int deformable_group) {
TORCH_CHECK(
weight.ndimension() == 4,
"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
"but got: %s",
weight.ndimension());
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
TORCH_CHECK(
kW > 0 && kH > 0,
"kernel size should be greater than zero, but got kH: %d kW: %d",
kH,
kW);
TORCH_CHECK(
(weight.size(2) == kH && weight.size(3) == kW),
"kernel size should be consistent with weight, ",
"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d",
kH,
kW,
weight.size(2),
weight.size(3));
TORCH_CHECK(
dW > 0 && dH > 0,
"stride should be greater than zero, but got dH: %d dW: %d",
dH,
dW);
TORCH_CHECK(
dilationW > 0 && dilationH > 0,
"dilation should be greater than 0, but got dilationH: %d dilationW: %d",
dilationH,
dilationW);
int ndim = input.ndimension();
int dimf = 0;
int dimh = 1;
int dimw = 2;
if (ndim == 4) {
dimf++;
dimh++;
dimw++;
}
TORCH_CHECK(
ndim == 3 || ndim == 4,
"3D or 4D input tensor expected but got: %s",
ndim);
long nInputPlane = weight.size(1) * group;
long inputHeight = input.size(dimh);
long inputWidth = input.size(dimw);
long nOutputPlane = weight.size(0);
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
TORCH_CHECK(
nInputPlane % deformable_group == 0,
"input channels must divide deformable group size");
if (outputWidth < 1 || outputHeight < 1)
AT_ERROR(
"Given input size: (%ld x %ld x %ld). "
"Calculated output size: (%ld x %ld x %ld). Output size is too small",
nInputPlane,
inputHeight,
inputWidth,
nOutputPlane,
outputHeight,
outputWidth);
TORCH_CHECK(
input.size(1) == nInputPlane,
"invalid number of input planes, expected: %d, but got: %d",
nInputPlane,
input.size(1));
TORCH_CHECK(
(inputHeight + 2 * padH >= kH && inputWidth + 2 * padW >= kW),
"input image is smaller than kernel");
TORCH_CHECK(
(offset.size(2) == outputHeight && offset.size(3) == outputWidth),
"invalid spatial size of offset, expected height: %d width: %d, but "
"got height: %d width: %d",
outputHeight,
outputWidth,
offset.size(2),
offset.size(3));
TORCH_CHECK(
(offset.size(1) == deformable_group * 2 * kH * kW),
"invalid number of channels of offset");
if (gradOutput != NULL) {
TORCH_CHECK(
gradOutput->size(dimf) == nOutputPlane,
"invalid number of gradOutput planes, expected: %d, but got: %d",
nOutputPlane,
gradOutput->size(dimf));
TORCH_CHECK(
(gradOutput->size(dimh) == outputHeight &&
gradOutput->size(dimw) == outputWidth),
"invalid size of gradOutput, expected height: %d width: %d , but "
"got height: %d width: %d",
outputHeight,
outputWidth,
gradOutput->size(dimh),
gradOutput->size(dimw));
}
}
int deform_conv_forward_cuda(
at::Tensor input,
at::Tensor weight,
at::Tensor offset,
at::Tensor output,
at::Tensor columns,
at::Tensor ones,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH,
int group,
int deformable_group,
int im2col_step) {
// todo: resize columns to include im2col: done
// todo: add im2col_step as input
// todo: add new output buffer and transpose it to output (or directly
// transpose output) todo: possibly change data indexing because of
// parallel_imgs
shape_check(
input,
offset,
NULL,
weight,
kH,
kW,
dH,
dW,
padH,
padW,
dilationH,
dilationW,
group,
deformable_group);
input = input.contiguous();
offset = offset.contiguous();
weight = weight.contiguous();
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input.unsqueeze_(0);
offset.unsqueeze_(0);
}
// todo: assert batchsize dividable by im2col_step
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = weight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
output = output.view(
{batchSize / im2col_step,
im2col_step,
nOutputPlane,
outputHeight,
outputWidth});
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
ones = at::ones({outputHeight, outputWidth}, input.options());
}
input = input.view(
{batchSize / im2col_step,
im2col_step,
nInputPlane,
inputHeight,
inputWidth});
offset = offset.view(
{batchSize / im2col_step,
im2col_step,
deformable_group * 2 * kH * kW,
outputHeight,
outputWidth});
at::Tensor output_buffer = at::zeros(
{batchSize / im2col_step,
nOutputPlane,
im2col_step * outputHeight,
outputWidth},
output.options());
output_buffer = output_buffer.view(
{output_buffer.size(0),
group,
output_buffer.size(1) / group,
output_buffer.size(2),
output_buffer.size(3)});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
deformable_im2col(
input[elt],
offset[elt],
nInputPlane,
inputHeight,
inputWidth,
kH,
kW,
padH,
padW,
dH,
dW,
dilationH,
dilationW,
im2col_step,
deformable_group,
columns);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view(
{group,
weight.size(0) / group,
weight.size(1),
weight.size(2),
weight.size(3)});
for (int g = 0; g < group; g++) {
output_buffer[elt][g] = output_buffer[elt][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output_buffer[elt][g]);
}
}
output_buffer = output_buffer.view(
{output_buffer.size(0),
output_buffer.size(1) * output_buffer.size(2),
output_buffer.size(3),
output_buffer.size(4)});
output_buffer = output_buffer.view(
{batchSize / im2col_step,
nOutputPlane,
im2col_step,
outputHeight,
outputWidth});
output_buffer.transpose_(1, 2);
output.copy_(output_buffer);
output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
output = output.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
}
return 1;
}
int deform_conv_backward_input_cuda(
at::Tensor input,
at::Tensor offset,
at::Tensor gradOutput,
at::Tensor gradInput,
at::Tensor gradOffset,
at::Tensor weight,
at::Tensor columns,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH,
int group,
int deformable_group,
int im2col_step) {
shape_check(
input,
offset,
&gradOutput,
weight,
kH,
kW,
dH,
dW,
padH,
padW,
dilationH,
dilationW,
group,
deformable_group);
input = input.contiguous();
offset = offset.contiguous();
gradOutput = gradOutput.contiguous();
weight = weight.contiguous();
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input = input.view({1, input.size(0), input.size(1), input.size(2)});
offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
gradOutput = gradOutput.view(
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = weight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
// change order of grad output
gradOutput = gradOutput.view(
{batchSize / im2col_step,
im2col_step,
nOutputPlane,
outputHeight,
outputWidth});
gradOutput.transpose_(1, 2);
gradInput = gradInput.view(
{batchSize / im2col_step,
im2col_step,
nInputPlane,
inputHeight,
inputWidth});
input = input.view(
{batchSize / im2col_step,
im2col_step,
nInputPlane,
inputHeight,
inputWidth});
gradOffset = gradOffset.view(
{batchSize / im2col_step,
im2col_step,
deformable_group * 2 * kH * kW,
outputHeight,
outputWidth});
offset = offset.view(
{batchSize / im2col_step,
im2col_step,
deformable_group * 2 * kH * kW,
outputHeight,
outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
// divide into groups
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view(
{group,
weight.size(0) / group,
weight.size(1),
weight.size(2),
weight.size(3)});
gradOutput = gradOutput.view(
{gradOutput.size(0),
group,
gradOutput.size(1) / group,
gradOutput.size(2),
gradOutput.size(3),
gradOutput.size(4)});
for (int g = 0; g < group; g++) {
columns[g] = columns[g].addmm_(
weight[g].flatten(1).transpose(0, 1),
gradOutput[elt][g].flatten(1),
0.0f,
1.0f);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradOutput = gradOutput.view(
{gradOutput.size(0),
gradOutput.size(1) * gradOutput.size(2),
gradOutput.size(3),
gradOutput.size(4),
gradOutput.size(5)});
deformable_col2im_coord(
columns,
input[elt],
offset[elt],
nInputPlane,
inputHeight,
inputWidth,
kH,
kW,
padH,
padW,
dH,
dW,
dilationH,
dilationW,
im2col_step,
deformable_group,
gradOffset[elt]);
deformable_col2im(
columns,
offset[elt],
nInputPlane,
inputHeight,
inputWidth,
kH,
kW,
padH,
padW,
dH,
dW,
dilationH,
dilationW,
im2col_step,
deformable_group,
gradInput[elt]);
}
gradOutput.transpose_(1, 2);
gradOutput =
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
gradOffset = gradOffset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
gradOffset =
gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
}
return 1;
}
int deform_conv_backward_parameters_cuda(
at::Tensor input,
at::Tensor offset,
at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns,
at::Tensor ones,
int kW,
int kH,
int dW,
int dH,
int padW,
int padH,
int dilationW,
int dilationH,
int group,
int deformable_group,
float scale,
int im2col_step) {
// todo: transpose and reshape outGrad
// todo: reshape columns
// todo: add im2col_step as input
shape_check(
input,
offset,
&gradOutput,
gradWeight,
kH,
kW,
dH,
dW,
padH,
padW,
dilationH,
dilationW,
group,
deformable_group);
input = input.contiguous();
offset = offset.contiguous();
gradOutput = gradOutput.contiguous();
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input = input.view(
at::IntList({1, input.size(0), input.size(1), input.size(2)}));
gradOutput = gradOutput.view(
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = gradWeight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
gradOutput = gradOutput.view(
{batchSize / im2col_step,
im2col_step,
nOutputPlane,
outputHeight,
outputWidth});
gradOutput.transpose_(1, 2);
at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
gradOutputBuffer = gradOutputBuffer.view(
{batchSize / im2col_step,
nOutputPlane,
im2col_step,
outputHeight,
outputWidth});
gradOutputBuffer.copy_(gradOutput);
// gradOutput is not contiguous, so we do reshape (instead of view) next
gradOutputBuffer = gradOutputBuffer.reshape(
{batchSize / im2col_step,
nOutputPlane,
im2col_step * outputHeight,
outputWidth});
gradOutput.transpose_(1, 2);
gradOutput =
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
input = input.view(
{batchSize / im2col_step,
im2col_step,
nInputPlane,
inputHeight,
inputWidth});
offset = offset.view(
{batchSize / im2col_step,
im2col_step,
deformable_group * 2 * kH * kW,
outputHeight,
outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
deformable_im2col(
input[elt],
offset[elt],
nInputPlane,
inputHeight,
inputWidth,
kH,
kW,
padH,
padW,
dH,
dW,
dilationH,
dilationW,
im2col_step,
deformable_group,
columns);
// divide into group
gradOutputBuffer = gradOutputBuffer.view(
{gradOutputBuffer.size(0),
group,
gradOutputBuffer.size(1) / group,
gradOutputBuffer.size(2),
gradOutputBuffer.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
gradWeight = gradWeight.view(
{group,
gradWeight.size(0) / group,
gradWeight.size(1),
gradWeight.size(2),
gradWeight.size(3)});
for (int g = 0; g < group; g++) {
gradWeight[g] = gradWeight[g]
.flatten(1)
.addmm_(
gradOutputBuffer[elt][g].flatten(1),
columns[g].transpose(1, 0),
1.0,
scale)
.view_as(gradWeight[g]);
}
gradOutputBuffer = gradOutputBuffer.view(
{gradOutputBuffer.size(0),
gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
gradOutputBuffer.size(3),
gradOutputBuffer.size(4)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradWeight = gradWeight.view(
{gradWeight.size(0) * gradWeight.size(1),
gradWeight.size(2),
gradWeight.size(3),
gradWeight.size(4)});
}
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
}
return 1;
}
void modulated_deform_conv_cuda_forward(
at::Tensor input,
at::Tensor weight,
at::Tensor bias,
at::Tensor ones,
at::Tensor offset,
at::Tensor mask,
at::Tensor output,
at::Tensor columns,
int kernel_h,
int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const int group,
const int deformable_group,
const bool with_bias) {
shape_check(
input,
offset,
NULL,
weight,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
group,
deformable_group);
TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_out = weight.size(0);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR(
"Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_,
kernel_w,
kernel_h_,
kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR(
"Input shape and kernel channels wont match: (%d vs %d).",
channels,
channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
// mask shape check
TORCH_CHECK(
(mask.size(2) == height_out && mask.size(3) == width_out),
"invalid spatial size of mask, expected height: %d width: %d, but "
"got height: %d width: %d",
height_out,
width_out,
mask.size(2),
mask.size(3));
TORCH_CHECK(
(mask.size(1) == deformable_group * kernel_h * kernel_w),
"invalid number of channels of mask");
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
// resize output
output = output.view({batch, channels_out, height_out, width_out}).zero_();
// resize temporary columns
columns = at::zeros(
{channels * kernel_h * kernel_w, 1 * height_out * width_out},
input.options());
output = output.view(
{output.size(0),
group,
output.size(1) / group,
output.size(2),
output.size(3)});
for (int b = 0; b < batch; b++) {
modulated_deformable_im2col_cuda(
input[b],
offset[b],
mask[b],
1,
channels,
height,
width,
height_out,
width_out,
kernel_h,
kernel_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
deformable_group,
columns);
// divide into group
weight = weight.view(
{group,
weight.size(0) / group,
weight.size(1),
weight.size(2),
weight.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
for (int g = 0; g < group; g++) {
output[b][g] = output[b][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output[b][g]);
}
weight = weight.view(
{weight.size(0) * weight.size(1),
weight.size(2),
weight.size(3),
weight.size(4)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
}
output = output.view(
{output.size(0),
output.size(1) * output.size(2),
output.size(3),
output.size(4)});
if (with_bias) {
output += bias.view({1, bias.size(0), 1, 1});
}
}
void modulated_deform_conv_cuda_backward(
at::Tensor input,
at::Tensor weight,
at::Tensor bias,
at::Tensor ones,
at::Tensor offset,
at::Tensor mask,
at::Tensor columns,
at::Tensor grad_input,
at::Tensor grad_weight,
at::Tensor grad_bias,
at::Tensor grad_offset,
at::Tensor grad_mask,
at::Tensor grad_output,
int kernel_h,
int kernel_w,
int stride_h,
int stride_w,
int pad_h,
int pad_w,
int dilation_h,
int dilation_w,
int group,
int deformable_group,
const bool with_bias) {
shape_check(
input,
offset,
&grad_output,
weight,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
group,
deformable_group);
TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR(
"Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_,
kernel_w,
kernel_h_,
kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR(
"Input shape and kernel channels wont match: (%d vs %d).",
channels,
channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
// mask shape check
TORCH_CHECK(
(mask.size(2) == height_out && mask.size(3) == width_out),
"invalid spatial size of mask, expected height: %d width: %d, but "
"got height: %d width: %d",
height_out,
width_out,
mask.size(2),
mask.size(3));
TORCH_CHECK(
(mask.size(1) == deformable_group * kernel_h * kernel_w),
"invalid number of channels of mask");
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
grad_input = grad_input.view({batch, channels, height, width});
columns = at::zeros(
{channels * kernel_h * kernel_w, height_out * width_out},
input.options());
grad_output = grad_output.view(
{grad_output.size(0),
group,
grad_output.size(1) / group,
grad_output.size(2),
grad_output.size(3)});
for (int b = 0; b < batch; b++) {
// divide int group
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view(
{group,
weight.size(0) / group,
weight.size(1),
weight.size(2),
weight.size(3)});
for (int g = 0; g < group; g++) {
columns[g].addmm_(
weight[g].flatten(1).transpose(0, 1),
grad_output[b][g].flatten(1),
0.0f,
1.0f);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
weight = weight.view(
{weight.size(0) * weight.size(1),
weight.size(2),
weight.size(3),
weight.size(4)});
// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cuda(
columns,
input[b],
offset[b],
mask[b],
1,
channels,
height,
width,
height_out,
width_out,
kernel_h,
kernel_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
deformable_group,
grad_offset[b],
grad_mask[b]);
// gradient w.r.t. input data
modulated_deformable_col2im_cuda(
columns,
offset[b],
mask[b],
1,
channels,
height,
width,
height_out,
width_out,
kernel_h,
kernel_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
deformable_group,
grad_input[b]);
// gradient w.r.t. weight, dWeight should accumulate across the batch and
// group
modulated_deformable_im2col_cuda(
input[b],
offset[b],
mask[b],
1,
channels,
height,
width,
height_out,
width_out,
kernel_h,
kernel_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
deformable_group,
columns);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
grad_weight = grad_weight.view(
{group,
grad_weight.size(0) / group,
grad_weight.size(1),
grad_weight.size(2),
grad_weight.size(3)});
if (with_bias)
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
for (int g = 0; g < group; g++) {
grad_weight[g] =
grad_weight[g]
.flatten(1)
.addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
.view_as(grad_weight[g]);
if (with_bias) {
grad_bias[g] =
grad_bias[g]
.view({-1, 1})
.addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
.view(-1);
}
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
grad_weight = grad_weight.view(
{grad_weight.size(0) * grad_weight.size(1),
grad_weight.size(2),
grad_weight.size(3),
grad_weight.size(4)});
if (with_bias)
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
}
grad_output = grad_output.view(
{grad_output.size(0) * grad_output.size(1),
grad_output.size(2),
grad_output.size(3),
grad_output.size(4)});
}
} // namespace detectron2
// Copyright (c) Facebook, Inc. and its affiliates.
// modified from
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
// Original license: Apache 2.0
// clang-format off
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
/*!
******************* BEGIN Caffe Copyright Notice and Disclaimer *****************
*
* COPYRIGHT
*
* All contributions by the University of California:
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
* All rights reserved.
*
* All other contributions:
* Copyright (c) 2014-2017, the respective contributors
* All rights reserved.
*
* Caffe uses a shared copyright model: each contributor holds copyright over
* their contributions to Caffe. The project versioning records all such
* contribution and copyright details. If a contributor wants to further mark
* their specific copyright on a particular contribution, they should indicate
* their copyright solely in the commit message of the change when it is
* committed.
*
* LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
*FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
*DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
*SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
*CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
*OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
*OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* CONTRIBUTION AGREEMENT
*
* By contributing to the BVLC/caffe repository through pull-request, comment,
* or otherwise, the contributor releases their content to the
* license and copyright terms herein.
*
***************** END Caffe Copyright Notice and Disclaimer *********************
*
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file modulated_deformable_im2col.cuh
* \brief Function definitions of converting an image to
* column matrix based on kernel, padding, dilation, and offset.
* These functions are mainly used in deformable convolution operators.
* \ref: https://arxiv.org/abs/1703.06211
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
*/
#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>
#include <float.h>
#include <math.h>
#include <stdio.h>
#include <THC/THCAtomics.cuh>
using namespace at;
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
namespace {
const int CUDA_NUM_THREADS = 1024;
const int kMaxGridNum = 65535;
inline int GET_BLOCKS(const int N) {
return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
}
}
template <typename scalar_t>
__device__ scalar_t deformable_im2col_bilinear(
const scalar_t* bottom_data,
const int data_width,
const int height,
const int width,
scalar_t h,
scalar_t w) {
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
scalar_t lh = h - h_low;
scalar_t lw = w - w_low;
scalar_t hh = 1 - lh, hw = 1 - lw;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
v1 = bottom_data[h_low * data_width + w_low];
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = bottom_data[h_low * data_width + w_high];
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = bottom_data[h_high * data_width + w_low];
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = bottom_data[h_high * data_width + w_high];
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
__device__ scalar_t get_gradient_weight(
scalar_t argmax_h,
scalar_t argmax_w,
const int h,
const int w,
const int height,
const int width) {
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
argmax_w >= width) {
// empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (h == argmax_h_low && w == argmax_w_low)
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
if (h == argmax_h_low && w == argmax_w_high)
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
if (h == argmax_h_high && w == argmax_w_low)
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
if (h == argmax_h_high && w == argmax_w_high)
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
return weight;
}
template <typename scalar_t>
__device__ scalar_t get_coordinate_weight(
scalar_t argmax_h,
scalar_t argmax_w,
const int height,
const int width,
const scalar_t* im_data,
const int data_width,
const int bp_dir) {
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
argmax_w >= width) {
// empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (bp_dir == 0) {
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_w_low + 1 - argmax_w) *
im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += -1 * (argmax_w - argmax_w_low) *
im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += (argmax_w_low + 1 - argmax_w) *
im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_w - argmax_w_low) *
im_data[argmax_h_high * data_width + argmax_w_high];
} else if (bp_dir == 1) {
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_h_low + 1 - argmax_h) *
im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += (argmax_h_low + 1 - argmax_h) *
im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += -1 * (argmax_h - argmax_h_low) *
im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_h - argmax_h_low) *
im_data[argmax_h_high * data_width + argmax_w_high];
}
return weight;
}
template <typename scalar_t>
__global__ void deformable_im2col_gpu_kernel(
const int n,
const scalar_t* data_im,
const scalar_t* data_offset,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int num_channels,
const int deformable_group,
const int height_col,
const int width_col,
scalar_t* data_col) {
CUDA_KERNEL_LOOP(index, n) {
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int b_col = (index / width_col / height_col) % batch_size;
const int c_im = (index / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
// compute deformable group index
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
scalar_t* data_col_ptr = data_col +
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
// const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) *
// height + h_in) * width + w_in;
const scalar_t* data_im_ptr =
data_im + (b_col * num_channels + c_im) * height * width;
const scalar_t* data_offset_ptr = data_offset +
(b_col * deformable_group + deformable_group_index) * 2 * kernel_h *
kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
w_col;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
scalar_t val = static_cast<scalar_t>(0);
const scalar_t h_im = h_in + i * dilation_h + offset_h;
const scalar_t w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
// const scalar_t map_h = i * dilation_h + offset_h;
// const scalar_t map_w = j * dilation_w + offset_w;
// const int cur_height = height - h_in;
// const int cur_width = width - w_in;
// val = deformable_im2col_bilinear(data_im_ptr, width, cur_height,
// cur_width, map_h, map_w);
val = deformable_im2col_bilinear(
data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val;
data_col_ptr += batch_size * height_col * width_col;
}
}
}
}
template <typename scalar_t>
__global__ void deformable_col2im_gpu_kernel(
const int n,
const scalar_t* data_col,
const scalar_t* data_offset,
const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int deformable_group,
const int height_col,
const int width_col,
scalar_t* grad_im) {
CUDA_KERNEL_LOOP(index, n) {
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i =
(index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c =
index / width_col / height_col / batch_size / kernel_w / kernel_h;
// compute the start and end of the output
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int b = (index / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const scalar_t* data_offset_ptr = data_offset +
(b * deformable_group + deformable_group_index) * 2 * kernel_h *
kernel_w * height_col * width_col;
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
const scalar_t cur_top_grad = data_col[index];
const int cur_h = (int)cur_inv_h_data;
const int cur_w = (int)cur_inv_w_data;
for (int dy = -2; dy <= 2; dy++) {
for (int dx = -2; dx <= 2; dx++) {
if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 &&
cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
abs(cur_inv_w_data - (cur_w + dx)) < 1) {
int cur_bottom_grad_pos =
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
scalar_t weight = get_gradient_weight(
cur_inv_h_data,
cur_inv_w_data,
cur_h + dy,
cur_w + dx,
height,
width);
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
}
}
}
}
}
template <typename scalar_t>
__global__ void deformable_col2im_coord_gpu_kernel(
const int n,
const scalar_t* data_col,
const scalar_t* data_im,
const scalar_t* data_offset,
const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int offset_channels,
const int deformable_group,
const int height_col,
const int width_col,
scalar_t* grad_offset) {
CUDA_KERNEL_LOOP(index, n) {
scalar_t val = 0;
int w = index % width_col;
int h = (index / width_col) % height_col;
int c = (index / width_col / height_col) % offset_channels;
int b = (index / width_col / height_col) / offset_channels;
// compute the start and end of the output
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const scalar_t* data_col_ptr = data_col +
deformable_group_index * channel_per_deformable_group * batch_size *
width_col * height_col;
const scalar_t* data_im_ptr = data_im +
(b * deformable_group + deformable_group_index) *
channel_per_deformable_group / kernel_h / kernel_w * height * width;
const scalar_t* data_offset_ptr = data_offset +
(b * deformable_group + deformable_group_index) * 2 * kernel_h *
kernel_w * height_col * width_col;
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group;
col_c += col_step) {
const int col_pos =
(((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i =
(col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr =
(((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr =
(((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
w_out);
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
scalar_t inv_h = h_in + i * dilation_h + offset_h;
scalar_t inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) {
inv_h = inv_w = -2;
}
const scalar_t weight = get_coordinate_weight(
inv_h,
inv_w,
height,
width,
data_im_ptr + cnt * height * width,
width,
bp_dir);
val += weight * data_col_ptr[col_pos];
cnt += 1;
}
grad_offset[index] = val;
}
}
namespace detectron2 {
void deformable_im2col(
const at::Tensor data_im,
const at::Tensor data_offset,
const int channels,
const int height,
const int width,
const int ksize_h,
const int ksize_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int parallel_imgs,
const int deformable_group,
at::Tensor data_col) {
// num_axes should be smaller than block size
// todo: check parallel_imgs is correctly passed in
int height_col =
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col =
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = channels * height_col * width_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
at::cuda::CUDAGuard device_guard(data_im.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
const scalar_t* data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t* data_offset_ = data_offset.data_ptr<scalar_t>();
scalar_t* data_col_ = data_col.data_ptr<scalar_t>();
deformable_im2col_gpu_kernel<<<
GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS,
0,
stream>>>(
num_kernels,
data_im_,
data_offset_,
height,
width,
ksize_h,
ksize_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
channel_per_deformable_group,
parallel_imgs,
channels,
deformable_group,
height_col,
width_col,
data_col_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
}
}
void deformable_col2im(
const at::Tensor data_col,
const at::Tensor data_offset,
const int channels,
const int height,
const int width,
const int ksize_h,
const int ksize_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int parallel_imgs,
const int deformable_group,
at::Tensor grad_im) {
// todo: make sure parallel_imgs is passed in correctly
int height_col =
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col =
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels =
channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
at::cuda::CUDAGuard device_guard(data_col.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
const scalar_t* data_col_ = data_col.data_ptr<scalar_t>();
const scalar_t* data_offset_ = data_offset.data_ptr<scalar_t>();
scalar_t* grad_im_ = grad_im.data_ptr<scalar_t>();
deformable_col2im_gpu_kernel<<<
GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS,
0,
stream>>>(
num_kernels,
data_col_,
data_offset_,
channels,
height,
width,
ksize_h,
ksize_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
channel_per_deformable_group,
parallel_imgs,
deformable_group,
height_col,
width_col,
grad_im_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
}
}
void deformable_col2im_coord(
const at::Tensor data_col,
const at::Tensor data_im,
const at::Tensor data_offset,
const int channels,
const int height,
const int width,
const int ksize_h,
const int ksize_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int parallel_imgs,
const int deformable_group,
at::Tensor grad_offset) {
int height_col =
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col =
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w *
deformable_group * parallel_imgs;
int channel_per_deformable_group =
channels * ksize_h * ksize_w / deformable_group;
at::cuda::CUDAGuard device_guard(data_col.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
const scalar_t* data_col_ = data_col.data_ptr<scalar_t>();
const scalar_t* data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t* data_offset_ = data_offset.data_ptr<scalar_t>();
scalar_t* grad_offset_ = grad_offset.data_ptr<scalar_t>();
deformable_col2im_coord_gpu_kernel<<<
GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS,
0,
stream>>>(
num_kernels,
data_col_,
data_im_,
data_offset_,
channels,
height,
width,
ksize_h,
ksize_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
channel_per_deformable_group,
parallel_imgs,
2 * ksize_h * ksize_w * deformable_group,
deformable_group,
height_col,
width_col,
grad_offset_);
}));
}
} // namespace detectron2
template <typename scalar_t>
__device__ scalar_t dmcn_im2col_bilinear(
const scalar_t* bottom_data,
const int data_width,
const int height,
const int width,
scalar_t h,
scalar_t w) {
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
scalar_t lh = h - h_low;
scalar_t lw = w - w_low;
scalar_t hh = 1 - lh, hw = 1 - lw;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
v1 = bottom_data[h_low * data_width + w_low];
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = bottom_data[h_low * data_width + w_high];
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = bottom_data[h_high * data_width + w_low];
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = bottom_data[h_high * data_width + w_high];
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
__device__ scalar_t dmcn_get_gradient_weight(
scalar_t argmax_h,
scalar_t argmax_w,
const int h,
const int w,
const int height,
const int width) {
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
argmax_w >= width) {
// empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (h == argmax_h_low && w == argmax_w_low)
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
if (h == argmax_h_low && w == argmax_w_high)
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
if (h == argmax_h_high && w == argmax_w_low)
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
if (h == argmax_h_high && w == argmax_w_high)
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
return weight;
}
template <typename scalar_t>
__device__ scalar_t dmcn_get_coordinate_weight(
scalar_t argmax_h,
scalar_t argmax_w,
const int height,
const int width,
const scalar_t* im_data,
const int data_width,
const int bp_dir) {
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
argmax_w >= width) {
// empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (bp_dir == 0) {
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_w_low + 1 - argmax_w) *
im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += -1 * (argmax_w - argmax_w_low) *
im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += (argmax_w_low + 1 - argmax_w) *
im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_w - argmax_w_low) *
im_data[argmax_h_high * data_width + argmax_w_high];
} else if (bp_dir == 1) {
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_h_low + 1 - argmax_h) *
im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += (argmax_h_low + 1 - argmax_h) *
im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += -1 * (argmax_h - argmax_h_low) *
im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_h - argmax_h_low) *
im_data[argmax_h_high * data_width + argmax_w_high];
}
return weight;
}
template <typename scalar_t>
__global__ void modulated_deformable_im2col_gpu_kernel(
const int n,
const scalar_t* data_im,
const scalar_t* data_offset,
const scalar_t* data_mask,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int num_channels,
const int deformable_group,
const int height_col,
const int width_col,
scalar_t* data_col) {
CUDA_KERNEL_LOOP(index, n) {
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int b_col = (index / width_col / height_col) % batch_size;
const int c_im = (index / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
// compute deformable group index
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
scalar_t* data_col_ptr = data_col +
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
// const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) *
// height + h_in) * width + w_in;
const scalar_t* data_im_ptr =
data_im + (b_col * num_channels + c_im) * height * width;
const scalar_t* data_offset_ptr = data_offset +
(b_col * deformable_group + deformable_group_index) * 2 * kernel_h *
kernel_w * height_col * width_col;
const scalar_t* data_mask_ptr = data_mask +
(b_col * deformable_group + deformable_group_index) * kernel_h *
kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
w_col;
const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
scalar_t val = static_cast<scalar_t>(0);
const scalar_t h_im = h_in + i * dilation_h + offset_h;
const scalar_t w_im = w_in + j * dilation_w + offset_w;
// if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
// const float map_h = i * dilation_h + offset_h;
// const float map_w = j * dilation_w + offset_w;
// const int cur_height = height - h_in;
// const int cur_width = width - w_in;
// val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height,
// cur_width, map_h, map_w);
val = dmcn_im2col_bilinear(
data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val * mask;
data_col_ptr += batch_size * height_col * width_col;
// data_col_ptr += height_col * width_col;
}
}
}
}
template <typename scalar_t>
__global__ void modulated_deformable_col2im_gpu_kernel(
const int n,
const scalar_t* data_col,
const scalar_t* data_offset,
const scalar_t* data_mask,
const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int deformable_group,
const int height_col,
const int width_col,
scalar_t* grad_im) {
CUDA_KERNEL_LOOP(index, n) {
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i =
(index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c =
index / width_col / height_col / batch_size / kernel_w / kernel_h;
// compute the start and end of the output
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int b = (index / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const scalar_t* data_offset_ptr = data_offset +
(b * deformable_group + deformable_group_index) * 2 * kernel_h *
kernel_w * height_col * width_col;
const scalar_t* data_mask_ptr = data_mask +
(b * deformable_group + deformable_group_index) * kernel_h * kernel_w *
height_col * width_col;
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
const scalar_t cur_top_grad = data_col[index] * mask;
const int cur_h = (int)cur_inv_h_data;
const int cur_w = (int)cur_inv_w_data;
for (int dy = -2; dy <= 2; dy++) {
for (int dx = -2; dx <= 2; dx++) {
if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 &&
cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
abs(cur_inv_w_data - (cur_w + dx)) < 1) {
int cur_bottom_grad_pos =
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
scalar_t weight = dmcn_get_gradient_weight(
cur_inv_h_data,
cur_inv_w_data,
cur_h + dy,
cur_w + dx,
height,
width);
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
}
}
}
}
}
template <typename scalar_t>
__global__ void modulated_deformable_col2im_coord_gpu_kernel(
const int n,
const scalar_t* data_col,
const scalar_t* data_im,
const scalar_t* data_offset,
const scalar_t* data_mask,
const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int offset_channels,
const int deformable_group,
const int height_col,
const int width_col,
scalar_t* grad_offset,
scalar_t* grad_mask) {
CUDA_KERNEL_LOOP(index, n) {
scalar_t val = 0, mval = 0;
int w = index % width_col;
int h = (index / width_col) % height_col;
int c = (index / width_col / height_col) % offset_channels;
int b = (index / width_col / height_col) / offset_channels;
// compute the start and end of the output
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const scalar_t* data_col_ptr = data_col +
deformable_group_index * channel_per_deformable_group * batch_size *
width_col * height_col;
const scalar_t* data_im_ptr = data_im +
(b * deformable_group + deformable_group_index) *
channel_per_deformable_group / kernel_h / kernel_w * height * width;
const scalar_t* data_offset_ptr = data_offset +
(b * deformable_group + deformable_group_index) * 2 * kernel_h *
kernel_w * height_col * width_col;
const scalar_t* data_mask_ptr = data_mask +
(b * deformable_group + deformable_group_index) * kernel_h * kernel_w *
height_col * width_col;
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group;
col_c += col_step) {
const int col_pos =
(((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i =
(col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr =
(((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr =
(((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
w_out);
const int data_mask_hw_ptr =
(((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
scalar_t inv_h = h_in + i * dilation_h + offset_h;
scalar_t inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) {
inv_h = inv_w = -2;
} else {
mval += data_col_ptr[col_pos] *
dmcn_im2col_bilinear(
data_im_ptr + cnt * height * width,
width,
height,
width,
inv_h,
inv_w);
}
const scalar_t weight = dmcn_get_coordinate_weight(
inv_h,
inv_w,
height,
width,
data_im_ptr + cnt * height * width,
width,
bp_dir);
val += weight * data_col_ptr[col_pos] * mask;
cnt += 1;
}
// KERNEL_ASSIGN(grad_offset[index], offset_req, val);
grad_offset[index] = val;
if (offset_c % 2 == 0)
// KERNEL_ASSIGN(grad_mask[(((b * deformable_group +
// deformable_group_index) * kernel_h * kernel_w + offset_c / 2) *
// height_col + h) * width_col + w], mask_req, mval);
grad_mask
[(((b * deformable_group + deformable_group_index) * kernel_h *
kernel_w +
offset_c / 2) *
height_col +
h) *
width_col +
w] = mval;
}
}
namespace detectron2 {
void modulated_deformable_im2col_cuda(
const at::Tensor data_im,
const at::Tensor data_offset,
const at::Tensor data_mask,
const int batch_size,
const int channels,
const int height_im,
const int width_im,
const int height_col,
const int width_col,
const int kernel_h,
const int kenerl_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int deformable_group,
at::Tensor data_col) {
// num_axes should be smaller than block size
const int channel_per_deformable_group = channels / deformable_group;
const int num_kernels = channels * batch_size * height_col * width_col;
at::cuda::CUDAGuard device_guard(data_im.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
const scalar_t* data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t* data_offset_ = data_offset.data_ptr<scalar_t>();
const scalar_t* data_mask_ = data_mask.data_ptr<scalar_t>();
scalar_t* data_col_ = data_col.data_ptr<scalar_t>();
modulated_deformable_im2col_gpu_kernel<<<
GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS,
0,
stream>>>(
num_kernels,
data_im_,
data_offset_,
data_mask_,
height_im,
width_im,
kernel_h,
kenerl_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
channel_per_deformable_group,
batch_size,
channels,
deformable_group,
height_col,
width_col,
data_col_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf(
"error in modulated_deformable_im2col_cuda: %s\n",
cudaGetErrorString(err));
}
}
void modulated_deformable_col2im_cuda(
const at::Tensor data_col,
const at::Tensor data_offset,
const at::Tensor data_mask,
const int batch_size,
const int channels,
const int height_im,
const int width_im,
const int height_col,
const int width_col,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int deformable_group,
at::Tensor grad_im) {
const int channel_per_deformable_group = channels / deformable_group;
const int num_kernels =
channels * kernel_h * kernel_w * batch_size * height_col * width_col;
at::cuda::CUDAGuard device_guard(data_col.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
const scalar_t* data_col_ = data_col.data_ptr<scalar_t>();
const scalar_t* data_offset_ = data_offset.data_ptr<scalar_t>();
const scalar_t* data_mask_ = data_mask.data_ptr<scalar_t>();
scalar_t* grad_im_ = grad_im.data_ptr<scalar_t>();
modulated_deformable_col2im_gpu_kernel<<<
GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS,
0,
stream>>>(
num_kernels,
data_col_,
data_offset_,
data_mask_,
channels,
height_im,
width_im,
kernel_h,
kernel_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
channel_per_deformable_group,
batch_size,
deformable_group,
height_col,
width_col,
grad_im_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf(
"error in modulated_deformable_col2im_cuda: %s\n",
cudaGetErrorString(err));
}
}
void modulated_deformable_col2im_coord_cuda(
const at::Tensor data_col,
const at::Tensor data_im,
const at::Tensor data_offset,
const at::Tensor data_mask,
const int batch_size,
const int channels,
const int height_im,
const int width_im,
const int height_col,
const int width_col,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int deformable_group,
at::Tensor grad_offset,
at::Tensor grad_mask) {
const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h *
kernel_w * deformable_group;
const int channel_per_deformable_group =
channels * kernel_h * kernel_w / deformable_group;
at::cuda::CUDAGuard device_guard(data_col.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
const scalar_t* data_col_ = data_col.data_ptr<scalar_t>();
const scalar_t* data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t* data_offset_ = data_offset.data_ptr<scalar_t>();
const scalar_t* data_mask_ = data_mask.data_ptr<scalar_t>();
scalar_t* grad_offset_ = grad_offset.data_ptr<scalar_t>();
scalar_t* grad_mask_ = grad_mask.data_ptr<scalar_t>();
modulated_deformable_col2im_coord_gpu_kernel<<<
GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS,
0,
stream>>>(
num_kernels,
data_col_,
data_im_,
data_offset_,
data_mask_,
channels,
height_im,
width_im,
kernel_h,
kernel_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
channel_per_deformable_group,
batch_size,
2 * kernel_h * kernel_w * deformable_group,
deformable_group,
height_col,
width_col,
grad_offset_,
grad_mask_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf(
"error in modulated_deformable_col2im_coord_cuda: %s\n",
cudaGetErrorString(err));
}
}
} // namespace detectron2
// Copyright (c) Facebook, Inc. and its affiliates.
#pragma once
#include <torch/types.h>
namespace detectron2 {
at::Tensor nms_rotated_cpu(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold);
#if defined(WITH_CUDA) || defined(WITH_HIP)
at::Tensor nms_rotated_cuda(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold);
#endif
// Interface for Python
// inline is needed to prevent multiple function definitions when this header is
// included by different cpps
inline at::Tensor nms_rotated(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold) {
assert(dets.device().is_cuda() == scores.device().is_cuda());
if (dets.device().is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
return nms_rotated_cuda(
dets.contiguous(), scores.contiguous(), iou_threshold);
#else
AT_ERROR("Detectron2 is not compiled with GPU support!");
#endif
}
return nms_rotated_cpu(dets.contiguous(), scores.contiguous(), iou_threshold);
}
} // namespace detectron2
// Copyright (c) Facebook, Inc. and its affiliates.
#include "../box_iou_rotated/box_iou_rotated_utils.h"
#include "nms_rotated.h"
namespace detectron2 {
template <typename scalar_t>
at::Tensor nms_rotated_cpu_kernel(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold) {
// nms_rotated_cpu_kernel is modified from torchvision's nms_cpu_kernel,
// however, the code in this function is much shorter because
// we delegate the IoU computation for rotated boxes to
// the single_box_iou_rotated function in box_iou_rotated_utils.h
AT_ASSERTM(dets.device().is_cpu(), "dets must be a CPU tensor");
AT_ASSERTM(scores.device().is_cpu(), "scores must be a CPU tensor");
AT_ASSERTM(
dets.scalar_type() == scores.scalar_type(),
"dets should have the same type as scores");
if (dets.numel() == 0) {
return at::empty({0}, dets.options().dtype(at::kLong));
}
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
auto ndets = dets.size(0);
at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));
auto suppressed = suppressed_t.data_ptr<uint8_t>();
auto keep = keep_t.data_ptr<int64_t>();
auto order = order_t.data_ptr<int64_t>();
int64_t num_to_keep = 0;
for (int64_t _i = 0; _i < ndets; _i++) {
auto i = order[_i];
if (suppressed[i] == 1) {
continue;
}
keep[num_to_keep++] = i;
for (int64_t _j = _i + 1; _j < ndets; _j++) {
auto j = order[_j];
if (suppressed[j] == 1) {
continue;
}
auto ovr = single_box_iou_rotated<scalar_t>(
dets[i].data_ptr<scalar_t>(), dets[j].data_ptr<scalar_t>());
if (ovr >= iou_threshold) {
suppressed[j] = 1;
}
}
}
return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
}
at::Tensor nms_rotated_cpu(
// input must be contiguous
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold) {
auto result = at::empty({0}, dets.options());
AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_rotated", [&] {
result = nms_rotated_cpu_kernel<scalar_t>(dets, scores, iou_threshold);
});
return result;
}
} // namespace detectron2
// Copyright (c) Facebook, Inc. and its affiliates.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#ifdef WITH_CUDA
#include "../box_iou_rotated/box_iou_rotated_utils.h"
#endif
// TODO avoid this when pytorch supports "same directory" hipification
#ifdef WITH_HIP
#include "box_iou_rotated/box_iou_rotated_utils.h"
#endif
using namespace detectron2;
namespace {
int const threadsPerBlock = sizeof(unsigned long long) * 8;
}
template <typename T>
__global__ void nms_rotated_cuda_kernel(
const int n_boxes,
const double iou_threshold,
const T* dev_boxes,
unsigned long long* dev_mask) {
// nms_rotated_cuda_kernel is modified from torchvision's nms_cuda_kernel
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
const int row_size =
min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
const int col_size =
min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
// Compared to nms_cuda_kernel, where each box is represented with 4 values
// (x1, y1, x2, y2), each rotated box is represented with 5 values
// (x_center, y_center, width, height, angle_degrees) here.
__shared__ T block_boxes[threadsPerBlock * 5];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 5 + 0] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
block_boxes[threadIdx.x * 5 + 1] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
block_boxes[threadIdx.x * 5 + 2] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
block_boxes[threadIdx.x * 5 + 3] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
block_boxes[threadIdx.x * 5 + 4] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
const T* cur_box = dev_boxes + cur_box_idx * 5;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
// Instead of devIoU used by original horizontal nms, here
// we use the single_box_iou_rotated function from box_iou_rotated_utils.h
if (single_box_iou_rotated<T>(cur_box, block_boxes + i * 5) >
iou_threshold) {
t |= 1ULL << i;
}
}
const int col_blocks = at::cuda::ATenCeilDiv(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
namespace detectron2 {
at::Tensor nms_rotated_cuda(
// input must be contiguous
const at::Tensor& dets,
const at::Tensor& scores,
double iou_threshold) {
// using scalar_t = float;
AT_ASSERTM(dets.is_cuda(), "dets must be a CUDA tensor");
AT_ASSERTM(scores.is_cuda(), "scores must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(dets.device());
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
auto dets_sorted = dets.index_select(0, order_t);
auto dets_num = dets.size(0);
const int col_blocks =
at::cuda::ATenCeilDiv(static_cast<int>(dets_num), threadsPerBlock);
at::Tensor mask =
at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));
dim3 blocks(col_blocks, col_blocks);
dim3 threads(threadsPerBlock);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(
dets_sorted.scalar_type(), "nms_rotated_kernel_cuda", [&] {
nms_rotated_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
dets_num,
iou_threshold,
dets_sorted.data_ptr<scalar_t>(),
(unsigned long long*)mask.data_ptr<int64_t>());
});
at::Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long* mask_host =
(unsigned long long*)mask_cpu.data_ptr<int64_t>();
std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
at::Tensor keep =
at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
int64_t* keep_out = keep.data_ptr<int64_t>();
int num_to_keep = 0;
for (int i = 0; i < dets_num; i++) {
int nblock = i / threadsPerBlock;
int inblock = i % threadsPerBlock;
if (!(remv[nblock] & (1ULL << inblock))) {
keep_out[num_to_keep++] = i;
unsigned long long* p = mask_host + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
AT_CUDA_CHECK(cudaGetLastError());
return order_t.index(
{keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
.to(order_t.device(), keep.scalar_type())});
}
} // namespace detectron2
// Copyright (c) Facebook, Inc. and its affiliates.
#include <torch/extension.h>
#include "ROIAlignRotated/ROIAlignRotated.h"
#include "box_iou_rotated/box_iou_rotated.h"
#include "cocoeval/cocoeval.h"
#include "deformable/deform_conv.h"
#include "nms_rotated/nms_rotated.h"
namespace detectron2 {
#if defined(WITH_CUDA) || defined(WITH_HIP)
extern int get_cudart_version();
#endif
std::string get_cuda_version() {
#if defined(WITH_CUDA) || defined(WITH_HIP)
std::ostringstream oss;
#if defined(WITH_CUDA)
oss << "CUDA ";
#else
oss << "HIP ";
#endif
// copied from
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
auto printCudaStyleVersion = [&](int v) {
oss << (v / 1000) << "." << (v / 10 % 100);
if (v % 10 != 0) {
oss << "." << (v % 10);
}
};
printCudaStyleVersion(get_cudart_version());
return oss.str();
#else // neither CUDA nor HIP
return std::string("not available");
#endif
}
bool has_cuda() {
#if defined(WITH_CUDA)
return true;
#else
return false;
#endif
}
// similar to
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp
std::string get_compiler_version() {
std::ostringstream ss;
#if defined(__GNUC__)
#ifndef __clang__
#if ((__GNUC__ <= 4) && (__GNUC_MINOR__ <= 8))
#error "GCC >= 4.9 is required!"
#endif
{ ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; }
#endif
#endif
#if defined(__clang_major__)
{
ss << "clang " << __clang_major__ << "." << __clang_minor__ << "."
<< __clang_patchlevel__;
}
#endif
#if defined(_MSC_VER)
{ ss << "MSVC " << _MSC_FULL_VER; }
#endif
return ss.str();
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_compiler_version", &get_compiler_version, "get_compiler_version");
m.def("get_cuda_version", &get_cuda_version, "get_cuda_version");
m.def("has_cuda", &has_cuda, "has_cuda");
m.def("box_iou_rotated", &box_iou_rotated, "IoU for rotated boxes");
m.def("deform_conv_forward", &deform_conv_forward, "deform_conv_forward");
m.def(
"deform_conv_backward_input",
&deform_conv_backward_input,
"deform_conv_backward_input");
m.def(
"deform_conv_backward_filter",
&deform_conv_backward_filter,
"deform_conv_backward_filter");
m.def(
"modulated_deform_conv_forward",
&modulated_deform_conv_forward,
"modulated_deform_conv_forward");
m.def(
"modulated_deform_conv_backward",
&modulated_deform_conv_backward,
"modulated_deform_conv_backward");
m.def("nms_rotated", &nms_rotated, "NMS for rotated boxes");
m.def(
"roi_align_rotated_forward",
&ROIAlignRotated_forward,
"Forward pass for Rotated ROI-Align Operator");
m.def(
"roi_align_rotated_backward",
&ROIAlignRotated_backward,
"Backward pass for Rotated ROI-Align Operator");
m.def("COCOevalAccumulate", &COCOeval::Accumulate, "COCOeval::Accumulate");
m.def(
"COCOevalEvaluateImages",
&COCOeval::EvaluateImages,
"COCOeval::EvaluateImages");
pybind11::class_<COCOeval::InstanceAnnotation>(m, "InstanceAnnotation")
.def(pybind11::init<uint64_t, double, double, bool, bool>());
pybind11::class_<COCOeval::ImageEvaluation>(m, "ImageEvaluation")
.def(pybind11::init<>());
}
#ifdef TORCH_LIBRARY
TORCH_LIBRARY(detectron2, m) {
m.def("nms_rotated", &nms_rotated);
}
#endif
} // namespace detectron2
# Copyright (c) Facebook, Inc. and its affiliates.
import math
from functools import lru_cache
import torch
from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from torchvision.ops import deform_conv2d
from detectron2 import _C
from .wrappers import _NewEmptyTensorOp
class _DeformConv(Function):
@staticmethod
def forward(
ctx,
input,
offset,
weight,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
im2col_step=64,
):
if input is not None and input.dim() != 4:
raise ValueError(
"Expected 4D tensor as input, got {}D tensor instead.".format(input.dim())
)
ctx.stride = _pair(stride)
ctx.padding = _pair(padding)
ctx.dilation = _pair(dilation)
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.im2col_step = im2col_step
ctx.save_for_backward(input, offset, weight)
output = input.new_empty(
_DeformConv._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride)
)
ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
if not input.is_cuda:
if deformable_groups != 1:
raise NotImplementedError(
"Deformable Conv with deformable_groups != 1 is not supported on CPUs!"
)
return deform_conv2d(
input, offset, weight, stride=stride, padding=padding, dilation=dilation
)
else:
cur_im2col_step = _DeformConv._cal_im2col_step(input.shape[0], ctx.im2col_step)
assert (input.shape[0] % cur_im2col_step) == 0, "im2col step must divide batchsize"
_C.deform_conv_forward(
input,
weight,
offset,
output,
ctx.bufs_[0],
ctx.bufs_[1],
weight.size(3),
weight.size(2),
ctx.stride[1],
ctx.stride[0],
ctx.padding[1],
ctx.padding[0],
ctx.dilation[1],
ctx.dilation[0],
ctx.groups,
ctx.deformable_groups,
cur_im2col_step,
)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
input, offset, weight = ctx.saved_tensors
grad_input = grad_offset = grad_weight = None
if not grad_output.is_cuda:
raise NotImplementedError("Deformable Conv is not supported on CPUs!")
else:
cur_im2col_step = _DeformConv._cal_im2col_step(input.shape[0], ctx.im2col_step)
assert (input.shape[0] % cur_im2col_step) == 0, "im2col step must divide batchsize"
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
_C.deform_conv_backward_input(
input,
offset,
grad_output,
grad_input,
grad_offset,
weight,
ctx.bufs_[0],
weight.size(3),
weight.size(2),
ctx.stride[1],
ctx.stride[0],
ctx.padding[1],
ctx.padding[0],
ctx.dilation[1],
ctx.dilation[0],
ctx.groups,
ctx.deformable_groups,
cur_im2col_step,
)
if ctx.needs_input_grad[2]:
grad_weight = torch.zeros_like(weight)
_C.deform_conv_backward_filter(
input,
offset,
grad_output,
grad_weight,
ctx.bufs_[0],
ctx.bufs_[1],
weight.size(3),
weight.size(2),
ctx.stride[1],
ctx.stride[0],
ctx.padding[1],
ctx.padding[0],
ctx.dilation[1],
ctx.dilation[0],
ctx.groups,
ctx.deformable_groups,
1,
cur_im2col_step,
)
return grad_input, grad_offset, grad_weight, None, None, None, None, None, None
@staticmethod
def _output_size(input, weight, padding, dilation, stride):
channels = weight.size(0)
output_size = (input.size(0), channels)
for d in range(input.dim() - 2):
in_size = input.size(d + 2)
pad = padding[d]
kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
stride_ = stride[d]
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1,)
if not all(map(lambda s: s > 0, output_size)):
raise ValueError(
"convolution input is too small (output would be {})".format(
"x".join(map(str, output_size))
)
)
return output_size
@staticmethod
@lru_cache(maxsize=128)
def _cal_im2col_step(input_size, default_size):
"""
Calculate proper im2col step size, which should be divisible by input_size and not larger
than prefer_size. Meanwhile the step size should be as large as possible to be more
efficient. So we choose the largest one among all divisors of input_size which are smaller
than prefer_size.
:param input_size: input batch size .
:param default_size: default preferred im2col step size.
:return: the largest proper step size.
"""
if input_size <= default_size:
return input_size
best_step = 1
for step in range(2, min(int(math.sqrt(input_size)) + 1, default_size)):
if input_size % step == 0:
if input_size // step <= default_size:
return input_size // step
best_step = step
return best_step
class _ModulatedDeformConv(Function):
@staticmethod
def forward(
ctx,
input,
offset,
mask,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
):
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.with_bias = bias is not None
if not ctx.with_bias:
bias = input.new_empty(1) # fake tensor
if not input.is_cuda:
raise NotImplementedError("Deformable Conv is not supported on CPUs!")
if (
weight.requires_grad
or mask.requires_grad
or offset.requires_grad
or input.requires_grad
):
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(_ModulatedDeformConv._infer_shape(ctx, input, weight))
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
_C.modulated_deform_conv_forward(
input,
weight,
bias,
ctx._bufs[0],
offset,
mask,
output,
ctx._bufs[1],
weight.shape[2],
weight.shape[3],
ctx.stride,
ctx.stride,
ctx.padding,
ctx.padding,
ctx.dilation,
ctx.dilation,
ctx.groups,
ctx.deformable_groups,
ctx.with_bias,
)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError("Deformable Conv is not supported on CPUs!")
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
grad_mask = torch.zeros_like(mask)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
_C.modulated_deform_conv_backward(
input,
weight,
bias,
ctx._bufs[0],
offset,
mask,
ctx._bufs[1],
grad_input,
grad_weight,
grad_bias,
grad_offset,
grad_mask,
grad_output,
weight.shape[2],
weight.shape[3],
ctx.stride,
ctx.stride,
ctx.padding,
ctx.padding,
ctx.dilation,
ctx.dilation,
ctx.groups,
ctx.deformable_groups,
ctx.with_bias,
)
if not ctx.with_bias:
grad_bias = None
return (
grad_input,
grad_offset,
grad_mask,
grad_weight,
grad_bias,
None,
None,
None,
None,
None,
)
@staticmethod
def _infer_shape(ctx, input, weight):
n = input.size(0)
channels_out = weight.size(0)
height, width = input.shape[2:4]
kernel_h, kernel_w = weight.shape[2:4]
height_out = (
height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)
) // ctx.stride + 1
width_out = (
width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)
) // ctx.stride + 1
return n, channels_out, height_out, width_out
deform_conv = _DeformConv.apply
modulated_deform_conv = _ModulatedDeformConv.apply
class DeformConv(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=False,
norm=None,
activation=None,
):
"""
Deformable convolution from :paper:`deformconv`.
Arguments are similar to :class:`Conv2D`. Extra arguments:
Args:
deformable_groups (int): number of groups used in deformable convolution.
norm (nn.Module, optional): a normalization layer
activation (callable(Tensor) -> Tensor): a callable activation function
"""
super(DeformConv, self).__init__()
assert not bias
assert in_channels % groups == 0, "in_channels {} cannot be divisible by groups {}".format(
in_channels, groups
)
assert (
out_channels % groups == 0
), "out_channels {} cannot be divisible by groups {}".format(out_channels, groups)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deformable_groups = deformable_groups
self.norm = norm
self.activation = activation
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)
)
self.bias = None
nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
def forward(self, x, offset):
if x.numel() == 0:
# When input is empty, we want to return a empty tensor with "correct" shape,
# So that the following operations will not panic
# if they check for the shape of the tensor.
# This computes the height and width of the output tensor
output_shape = [
(i + 2 * p - (di * (k - 1) + 1)) // s + 1
for i, p, di, k, s in zip(
x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride
)
]
output_shape = [x.shape[0], self.weight.shape[0]] + output_shape
return _NewEmptyTensorOp.apply(x, output_shape)
x = deform_conv(
x,
offset,
self.weight,
self.stride,
self.padding,
self.dilation,
self.groups,
self.deformable_groups,
)
if self.norm is not None:
x = self.norm(x)
if self.activation is not None:
x = self.activation(x)
return x
def extra_repr(self):
tmpstr = "in_channels=" + str(self.in_channels)
tmpstr += ", out_channels=" + str(self.out_channels)
tmpstr += ", kernel_size=" + str(self.kernel_size)
tmpstr += ", stride=" + str(self.stride)
tmpstr += ", padding=" + str(self.padding)
tmpstr += ", dilation=" + str(self.dilation)
tmpstr += ", groups=" + str(self.groups)
tmpstr += ", deformable_groups=" + str(self.deformable_groups)
tmpstr += ", bias=False"
return tmpstr
class ModulatedDeformConv(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True,
norm=None,
activation=None,
):
"""
Modulated deformable convolution from :paper:`deformconv2`.
Arguments are similar to :class:`Conv2D`. Extra arguments:
Args:
deformable_groups (int): number of groups used in deformable convolution.
norm (nn.Module, optional): a normalization layer
activation (callable(Tensor) -> Tensor): a callable activation function
"""
super(ModulatedDeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.deformable_groups = deformable_groups
self.with_bias = bias
self.norm = norm
self.activation = activation
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
)
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = None
nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
if self.bias is not None:
nn.init.constant_(self.bias, 0)
def forward(self, x, offset, mask):
if x.numel() == 0:
output_shape = [
(i + 2 * p - (di * (k - 1) + 1)) // s + 1
for i, p, di, k, s in zip(
x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride
)
]
output_shape = [x.shape[0], self.weight.shape[0]] + output_shape
return _NewEmptyTensorOp.apply(x, output_shape)
x = modulated_deform_conv(
x,
offset,
mask,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups,
self.deformable_groups,
)
if self.norm is not None:
x = self.norm(x)
if self.activation is not None:
x = self.activation(x)
return x
def extra_repr(self):
tmpstr = "in_channels=" + str(self.in_channels)
tmpstr += ", out_channels=" + str(self.out_channels)
tmpstr += ", kernel_size=" + str(self.kernel_size)
tmpstr += ", stride=" + str(self.stride)
tmpstr += ", padding=" + str(self.padding)
tmpstr += ", dilation=" + str(self.dilation)
tmpstr += ", groups=" + str(self.groups)
tmpstr += ", deformable_groups=" + str(self.deformable_groups)
tmpstr += ", bias=" + str(self.with_bias)
return tmpstr
import math
import torch
def diou_loss(
boxes1: torch.Tensor,
boxes2: torch.Tensor,
reduction: str = "none",
eps: float = 1e-7,
) -> torch.Tensor:
"""
Distance Intersection over Union Loss (Zhaohui Zheng et. al)
https://arxiv.org/abs/1911.08287
Args:
boxes1, boxes2 (Tensor): box locations in XYXY format, shape (N, 4) or (4,).
reduction: 'none' | 'mean' | 'sum'
'none': No reduction will be applied to the output.
'mean': The output will be averaged.
'sum': The output will be summed.
eps (float): small number to prevent division by zero
"""
x1, y1, x2, y2 = boxes1.unbind(dim=-1)
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
# TODO: use torch._assert_async() when pytorch 1.8 support is dropped
assert (x2 >= x1).all(), "bad box: x1 larger than x2"
assert (y2 >= y1).all(), "bad box: y1 larger than y2"
# Intersection keypoints
xkis1 = torch.max(x1, x1g)
ykis1 = torch.max(y1, y1g)
xkis2 = torch.min(x2, x2g)
ykis2 = torch.min(y2, y2g)
intsct = torch.zeros_like(x1)
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
iou = intsct / union
# smallest enclosing box
xc1 = torch.min(x1, x1g)
yc1 = torch.min(y1, y1g)
xc2 = torch.max(x2, x2g)
yc2 = torch.max(y2, y2g)
diag_len = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps
# centers of boxes
x_p = (x2 + x1) / 2
y_p = (y2 + y1) / 2
x_g = (x1g + x2g) / 2
y_g = (y1g + y2g) / 2
distance = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)
# Eqn. (7)
loss = 1 - iou + (distance / diag_len)
if reduction == "mean":
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
elif reduction == "sum":
loss = loss.sum()
return loss
def ciou_loss(
boxes1: torch.Tensor,
boxes2: torch.Tensor,
reduction: str = "none",
eps: float = 1e-7,
) -> torch.Tensor:
"""
Complete Intersection over Union Loss (Zhaohui Zheng et. al)
https://arxiv.org/abs/1911.08287
Args:
boxes1, boxes2 (Tensor): box locations in XYXY format, shape (N, 4) or (4,).
reduction: 'none' | 'mean' | 'sum'
'none': No reduction will be applied to the output.
'mean': The output will be averaged.
'sum': The output will be summed.
eps (float): small number to prevent division by zero
"""
x1, y1, x2, y2 = boxes1.unbind(dim=-1)
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
# TODO: use torch._assert_async() when pytorch 1.8 support is dropped
assert (x2 >= x1).all(), "bad box: x1 larger than x2"
assert (y2 >= y1).all(), "bad box: y1 larger than y2"
# Intersection keypoints
xkis1 = torch.max(x1, x1g)
ykis1 = torch.max(y1, y1g)
xkis2 = torch.min(x2, x2g)
ykis2 = torch.min(y2, y2g)
intsct = torch.zeros_like(x1)
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
iou = intsct / union
# smallest enclosing box
xc1 = torch.min(x1, x1g)
yc1 = torch.min(y1, y1g)
xc2 = torch.max(x2, x2g)
yc2 = torch.max(y2, y2g)
diag_len = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps
# centers of boxes
x_p = (x2 + x1) / 2
y_p = (y2 + y1) / 2
x_g = (x1g + x2g) / 2
y_g = (y1g + y2g) / 2
distance = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)
# width and height of boxes
w_pred = x2 - x1
h_pred = y2 - y1
w_gt = x2g - x1g
h_gt = y2g - y1g
v = (4 / (math.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2)
with torch.no_grad():
alpha = v / (1 - iou + v + eps)
# Eqn. (10)
loss = 1 - iou + (distance / diag_len) + alpha * v
if reduction == "mean":
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
elif reduction == "sum":
loss = loss.sum()
return loss
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
from typing import Tuple
import torch
from PIL import Image
from torch.nn import functional as F
from detectron2.structures import Boxes
__all__ = ["paste_masks_in_image"]
BYTES_PER_FLOAT = 4
# TODO: This memory limit may be too much or too little. It would be better to
# determine it based on available resources.
GPU_MEM_LIMIT = 1024 ** 3 # 1 GB memory limit
def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True):
"""
Args:
masks: N, 1, H, W
boxes: N, 4
img_h, img_w (int):
skip_empty (bool): only paste masks within the region that
tightly bound all boxes, and returns the results this region only.
An important optimization for CPU.
Returns:
if skip_empty == False, a mask of shape (N, img_h, img_w)
if skip_empty == True, a mask of shape (N, h', w'), and the slice
object for the corresponding region.
"""
# On GPU, paste all masks together (up to chunk size)
# by using the entire image to sample the masks
# Compared to pasting them one by one,
# this has more operations but is faster on COCO-scale dataset.
device = masks.device
if skip_empty and not torch.jit.is_scripting():
x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(
dtype=torch.int32
)
x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
else:
x0_int, y0_int = 0, 0
x1_int, y1_int = img_w, img_h
x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1
N = masks.shape[0]
img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
img_y = (img_y - y0) / (y1 - y0) * 2 - 1
img_x = (img_x - x0) / (x1 - x0) * 2 - 1
# img_x, img_y have shapes (N, w), (N, h)
gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
grid = torch.stack([gx, gy], dim=3)
if not torch.jit.is_scripting():
if not masks.dtype.is_floating_point:
masks = masks.float()
img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)
if skip_empty and not torch.jit.is_scripting():
return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
else:
return img_masks[:, 0], ()
def paste_masks_in_image(
masks: torch.Tensor, boxes: Boxes, image_shape: Tuple[int, int], threshold: float = 0.5
):
"""
Paste a set of masks that are of a fixed resolution (e.g., 28 x 28) into an image.
The location, height, and width for pasting each mask is determined by their
corresponding bounding boxes in boxes.
Note:
This is a complicated but more accurate implementation. In actual deployment, it is
often enough to use a faster but less accurate implementation.
See :func:`paste_mask_in_image_old` in this file for an alternative implementation.
Args:
masks (tensor): Tensor of shape (Bimg, Hmask, Wmask), where Bimg is the number of
detected object instances in the image and Hmask, Wmask are the mask width and mask
height of the predicted mask (e.g., Hmask = Wmask = 28). Values are in [0, 1].
boxes (Boxes or Tensor): A Boxes of length Bimg or Tensor of shape (Bimg, 4).
boxes[i] and masks[i] correspond to the same object instance.
image_shape (tuple): height, width
threshold (float): A threshold in [0, 1] for converting the (soft) masks to
binary masks.
Returns:
img_masks (Tensor): A tensor of shape (Bimg, Himage, Wimage), where Bimg is the
number of detected object instances and Himage, Wimage are the image width
and height. img_masks[i] is a binary mask for object instance i.
"""
assert masks.shape[-1] == masks.shape[-2], "Only square mask predictions are supported"
N = len(masks)
if N == 0:
return masks.new_empty((0,) + image_shape, dtype=torch.uint8)
if not isinstance(boxes, torch.Tensor):
boxes = boxes.tensor
device = boxes.device
assert len(boxes) == N, boxes.shape
img_h, img_w = image_shape
# The actual implementation split the input into chunks,
# and paste them chunk by chunk.
if device.type == "cpu" or torch.jit.is_scripting():
# CPU is most efficient when they are pasted one by one with skip_empty=True
# so that it performs minimal number of operations.
num_chunks = N
else:
# GPU benefits from parallelism for larger chunks, but may have memory issue
# int(img_h) because shape may be tensors in tracing
num_chunks = int(np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
assert (
num_chunks <= N
), "Default GPU_MEM_LIMIT in mask_ops.py is too small; try increasing it"
chunks = torch.chunk(torch.arange(N, device=device), num_chunks)
img_masks = torch.zeros(
N, img_h, img_w, device=device, dtype=torch.bool if threshold >= 0 else torch.uint8
)
for inds in chunks:
masks_chunk, spatial_inds = _do_paste_mask(
masks[inds, None, :, :], boxes[inds], img_h, img_w, skip_empty=device.type == "cpu"
)
if threshold >= 0:
masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
else:
# for visualization and debugging
masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)
if torch.jit.is_scripting(): # Scripting does not use the optimized codepath
img_masks[inds] = masks_chunk
else:
img_masks[(inds,) + spatial_inds] = masks_chunk
return img_masks
# The below are the original paste function (from Detectron1) which has
# larger quantization error.
# It is faster on CPU, while the aligned one is faster on GPU thanks to grid_sample.
def paste_mask_in_image_old(mask, box, img_h, img_w, threshold):
"""
Paste a single mask in an image.
This is a per-box implementation of :func:`paste_masks_in_image`.
This function has larger quantization error due to incorrect pixel
modeling and is not used any more.
Args:
mask (Tensor): A tensor of shape (Hmask, Wmask) storing the mask of a single
object instance. Values are in [0, 1].
box (Tensor): A tensor of shape (4, ) storing the x0, y0, x1, y1 box corners
of the object instance.
img_h, img_w (int): Image height and width.
threshold (float): Mask binarization threshold in [0, 1].
Returns:
im_mask (Tensor):
The resized and binarized object mask pasted into the original
image plane (a tensor of shape (img_h, img_w)).
"""
# Conversion from continuous box coordinates to discrete pixel coordinates
# via truncation (cast to int32). This determines which pixels to paste the
# mask onto.
box = box.to(dtype=torch.int32) # Continuous to discrete coordinate conversion
# An example (1D) box with continuous coordinates (x0=0.7, x1=4.3) will map to
# a discrete coordinates (x0=0, x1=4). Note that box is mapped to 5 = x1 - x0 + 1
# pixels (not x1 - x0 pixels).
samples_w = box[2] - box[0] + 1 # Number of pixel samples, *not* geometric width
samples_h = box[3] - box[1] + 1 # Number of pixel samples, *not* geometric height
# Resample the mask from it's original grid to the new samples_w x samples_h grid
mask = Image.fromarray(mask.cpu().numpy())
mask = mask.resize((samples_w, samples_h), resample=Image.BILINEAR)
mask = np.array(mask, copy=False)
if threshold >= 0:
mask = np.array(mask > threshold, dtype=np.uint8)
mask = torch.from_numpy(mask)
else:
# for visualization and debugging, we also
# allow it to return an unmodified mask
mask = torch.from_numpy(mask * 255).to(torch.uint8)
im_mask = torch.zeros((img_h, img_w), dtype=torch.uint8)
x_0 = max(box[0], 0)
x_1 = min(box[2] + 1, img_w)
y_0 = max(box[1], 0)
y_1 = min(box[3] + 1, img_h)
im_mask[y_0:y_1, x_0:x_1] = mask[
(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])
]
return im_mask
# Our pixel modeling requires extrapolation for any continuous
# coordinate < 0.5 or > length - 0.5. When sampling pixels on the masks,
# we would like this extrapolation to be an interpolation between boundary values and zero,
# instead of using absolute zero or boundary values.
# Therefore `paste_mask_in_image_old` is often used with zero padding around the masks like this:
# masks, scale = pad_masks(masks[:, 0, :, :], 1)
# boxes = scale_boxes(boxes.tensor, scale)
def pad_masks(masks, padding):
"""
Args:
masks (tensor): A tensor of shape (B, M, M) representing B masks.
padding (int): Number of cells to pad on all sides.
Returns:
The padded masks and the scale factor of the padding size / original size.
"""
B = masks.shape[0]
M = masks.shape[-1]
pad2 = 2 * padding
scale = float(M + pad2) / M
padded_masks = masks.new_zeros((B, M + pad2, M + pad2))
padded_masks[:, padding:-padding, padding:-padding] = masks
return padded_masks, scale
def scale_boxes(boxes, scale):
"""
Args:
boxes (tensor): A tensor of shape (B, 4) representing B boxes with 4
coords representing the corners x0, y0, x1, y1,
scale (float): The box scaling factor.
Returns:
Scaled boxes.
"""
w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
w_half *= scale
h_half *= scale
scaled_boxes = torch.zeros_like(boxes)
scaled_boxes[:, 0] = x_c - w_half
scaled_boxes[:, 2] = x_c + w_half
scaled_boxes[:, 1] = y_c - h_half
scaled_boxes[:, 3] = y_c + h_half
return scaled_boxes
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
from typing import List
import torch
from torchvision.ops import boxes as box_ops
from torchvision.ops import nms # BC-compat
def batched_nms(
boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float
):
"""
Same as torchvision.ops.boxes.batched_nms, but safer.
"""
assert boxes.shape[-1] == 4
# TODO may need better strategy.
# Investigate after having a fully-cuda NMS op.
if len(boxes) < 40000:
# fp16 does not have enough range for batched NMS
return box_ops.batched_nms(boxes.float(), scores, idxs, iou_threshold)
result_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
for id in torch.jit.annotate(List[int], torch.unique(idxs).cpu().tolist()):
mask = (idxs == id).nonzero().view(-1)
keep = nms(boxes[mask], scores[mask], iou_threshold)
result_mask[mask[keep]] = True
keep = result_mask.nonzero().view(-1)
keep = keep[scores[keep].argsort(descending=True)]
return keep
# Note: this function (nms_rotated) might be moved into
# torchvision/ops/boxes.py in the future
def nms_rotated(boxes, scores, iou_threshold):
"""
Performs non-maximum suppression (NMS) on the rotated boxes according
to their intersection-over-union (IoU).
Rotated NMS iteratively removes lower scoring rotated boxes which have an
IoU greater than iou_threshold with another (higher scoring) rotated box.
Note that RotatedBox (5, 3, 4, 2, -90) covers exactly the same region as
RotatedBox (5, 3, 4, 2, 90) does, and their IoU will be 1. However, they
can be representing completely different objects in certain tasks, e.g., OCR.
As for the question of whether rotated-NMS should treat them as faraway boxes
even though their IOU is 1, it depends on the application and/or ground truth annotation.
As an extreme example, consider a single character v and the square box around it.
If the angle is 0 degree, the object (text) would be read as 'v';
If the angle is 90 degrees, the object (text) would become '>';
If the angle is 180 degrees, the object (text) would become '^';
If the angle is 270/-90 degrees, the object (text) would become '<'
All of these cases have IoU of 1 to each other, and rotated NMS that only
uses IoU as criterion would only keep one of them with the highest score -
which, practically, still makes sense in most cases because typically
only one of theses orientations is the correct one. Also, it does not matter
as much if the box is only used to classify the object (instead of transcribing
them with a sequential OCR recognition model) later.
On the other hand, when we use IoU to filter proposals that are close to the
ground truth during training, we should definitely take the angle into account if
we know the ground truth is labeled with the strictly correct orientation (as in,
upside-down words are annotated with -180 degrees even though they can be covered
with a 0/90/-90 degree box, etc.)
The way the original dataset is annotated also matters. For example, if the dataset
is a 4-point polygon dataset that does not enforce ordering of vertices/orientation,
we can estimate a minimum rotated bounding box to this polygon, but there's no way
we can tell the correct angle with 100% confidence (as shown above, there could be 4 different
rotated boxes, with angles differed by 90 degrees to each other, covering the exactly
same region). In that case we have to just use IoU to determine the box
proximity (as many detection benchmarks (even for text) do) unless there're other
assumptions we can make (like width is always larger than height, or the object is not
rotated by more than 90 degrees CCW/CW, etc.)
In summary, not considering angles in rotated NMS seems to be a good option for now,
but we should be aware of its implications.
Args:
boxes (Tensor[N, 5]): Rotated boxes to perform NMS on. They are expected to be in
(x_center, y_center, width, height, angle_degrees) format.
scores (Tensor[N]): Scores for each one of the rotated boxes
iou_threshold (float): Discards all overlapping rotated boxes with IoU < iou_threshold
Returns:
keep (Tensor): int64 tensor with the indices of the elements that have been kept
by Rotated NMS, sorted in decreasing order of scores
"""
return torch.ops.detectron2.nms_rotated(boxes, scores, iou_threshold)
# Note: this function (batched_nms_rotated) might be moved into
# torchvision/ops/boxes.py in the future
def batched_nms_rotated(boxes, scores, idxs, iou_threshold):
"""
Performs non-maximum suppression in a batched fashion.
Each index value correspond to a category, and NMS
will not be applied between elements of different categories.
Args:
boxes (Tensor[N, 5]):
boxes where NMS will be performed. They
are expected to be in (x_ctr, y_ctr, width, height, angle_degrees) format
scores (Tensor[N]):
scores for each one of the boxes
idxs (Tensor[N]):
indices of the categories for each one of the boxes.
iou_threshold (float):
discards all overlapping boxes
with IoU < iou_threshold
Returns:
Tensor:
int64 tensor with the indices of the elements that have been kept
by NMS, sorted in decreasing order of scores
"""
assert boxes.shape[-1] == 5
if boxes.numel() == 0:
return torch.empty((0,), dtype=torch.int64, device=boxes.device)
boxes = boxes.float() # fp16 does not have enough range for batched NMS
# Strategy: in order to perform NMS independently per class,
# we add an offset to all the boxes. The offset is dependent
# only on the class idx, and is large enough so that boxes
# from different classes do not overlap
# Note that batched_nms in torchvision/ops/boxes.py only uses max_coordinate,
# which won't handle negative coordinates correctly.
# Here by using min_coordinate we can make sure the negative coordinates are
# correctly handled.
max_coordinate = (
torch.max(boxes[:, 0], boxes[:, 1]) + torch.max(boxes[:, 2], boxes[:, 3]) / 2
).max()
min_coordinate = (
torch.min(boxes[:, 0], boxes[:, 1]) - torch.max(boxes[:, 2], boxes[:, 3]) / 2
).min()
offsets = idxs.to(boxes) * (max_coordinate - min_coordinate + 1)
boxes_for_nms = boxes.clone() # avoid modifying the original values in boxes
boxes_for_nms[:, :2] += offsets[:, None]
keep = nms_rotated(boxes_for_nms, scores, iou_threshold)
return keep
# Copyright (c) Facebook, Inc. and its affiliates.
from torch import nn
from torchvision.ops import roi_align
# NOTE: torchvision's RoIAlign has a different default aligned=False
class ROIAlign(nn.Module):
def __init__(self, output_size, spatial_scale, sampling_ratio, aligned=True):
"""
Args:
output_size (tuple): h, w
spatial_scale (float): scale the input boxes by this number
sampling_ratio (int): number of inputs samples to take for each output
sample. 0 to take samples densely.
aligned (bool): if False, use the legacy implementation in
Detectron. If True, align the results more perfectly.
Note:
The meaning of aligned=True:
Given a continuous coordinate c, its two neighboring pixel indices (in our
pixel model) are computed by floor(c - 0.5) and ceil(c - 0.5). For example,
c=1.3 has pixel neighbors with discrete indices [0] and [1] (which are sampled
from the underlying signal at continuous coordinates 0.5 and 1.5). But the original
roi_align (aligned=False) does not subtract the 0.5 when computing neighboring
pixel indices and therefore it uses pixels with a slightly incorrect alignment
(relative to our pixel model) when performing bilinear interpolation.
With `aligned=True`,
we first appropriately scale the ROI and then shift it by -0.5
prior to calling roi_align. This produces the correct neighbors; see
detectron2/tests/test_roi_align.py for verification.
The difference does not make a difference to the model's performance if
ROIAlign is used together with conv layers.
"""
super().__init__()
self.output_size = output_size
self.spatial_scale = spatial_scale
self.sampling_ratio = sampling_ratio
self.aligned = aligned
from torchvision import __version__
version = tuple(int(x) for x in __version__.split(".")[:2])
# https://github.com/pytorch/vision/pull/2438
assert version >= (0, 7), "Require torchvision >= 0.7"
def forward(self, input, rois):
"""
Args:
input: NCHW images
rois: Bx5 boxes. First column is the index into N. The other 4 columns are xyxy.
"""
assert rois.dim() == 2 and rois.size(1) == 5
if input.is_quantized:
input = input.dequantize()
return roi_align(
input,
rois.to(dtype=input.dtype),
self.output_size,
self.spatial_scale,
self.sampling_ratio,
self.aligned,
)
def __repr__(self):
tmpstr = self.__class__.__name__ + "("
tmpstr += "output_size=" + str(self.output_size)
tmpstr += ", spatial_scale=" + str(self.spatial_scale)
tmpstr += ", sampling_ratio=" + str(self.sampling_ratio)
tmpstr += ", aligned=" + str(self.aligned)
tmpstr += ")"
return tmpstr
# Copyright (c) Facebook, Inc. and its affiliates.
import torch
from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from detectron2 import _C
class _ROIAlignRotated(Function):
@staticmethod
def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio):
ctx.save_for_backward(roi)
ctx.output_size = _pair(output_size)
ctx.spatial_scale = spatial_scale
ctx.sampling_ratio = sampling_ratio
ctx.input_shape = input.size()
output = _C.roi_align_rotated_forward(
input, roi, spatial_scale, output_size[0], output_size[1], sampling_ratio
)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
(rois,) = ctx.saved_tensors
output_size = ctx.output_size
spatial_scale = ctx.spatial_scale
sampling_ratio = ctx.sampling_ratio
bs, ch, h, w = ctx.input_shape
grad_input = _C.roi_align_rotated_backward(
grad_output,
rois,
spatial_scale,
output_size[0],
output_size[1],
bs,
ch,
h,
w,
sampling_ratio,
)
return grad_input, None, None, None, None, None
roi_align_rotated = _ROIAlignRotated.apply
class ROIAlignRotated(nn.Module):
def __init__(self, output_size, spatial_scale, sampling_ratio):
"""
Args:
output_size (tuple): h, w
spatial_scale (float): scale the input boxes by this number
sampling_ratio (int): number of inputs samples to take for each output
sample. 0 to take samples densely.
Note:
ROIAlignRotated supports continuous coordinate by default:
Given a continuous coordinate c, its two neighboring pixel indices (in our
pixel model) are computed by floor(c - 0.5) and ceil(c - 0.5). For example,
c=1.3 has pixel neighbors with discrete indices [0] and [1] (which are sampled
from the underlying signal at continuous coordinates 0.5 and 1.5).
"""
super(ROIAlignRotated, self).__init__()
self.output_size = output_size
self.spatial_scale = spatial_scale
self.sampling_ratio = sampling_ratio
def forward(self, input, rois):
"""
Args:
input: NCHW images
rois: Bx6 boxes. First column is the index into N.
The other 5 columns are (x_ctr, y_ctr, width, height, angle_degrees).
"""
assert rois.dim() == 2 and rois.size(1) == 6
orig_dtype = input.dtype
if orig_dtype == torch.float16:
input = input.float()
rois = rois.float()
return roi_align_rotated(
input, rois, self.output_size, self.spatial_scale, self.sampling_ratio
).to(dtype=orig_dtype)
def __repr__(self):
tmpstr = self.__class__.__name__ + "("
tmpstr += "output_size=" + str(self.output_size)
tmpstr += ", spatial_scale=" + str(self.spatial_scale)
tmpstr += ", sampling_ratio=" + str(self.sampling_ratio)
tmpstr += ")"
return tmpstr
# Copyright (c) Facebook, Inc. and its affiliates.
from __future__ import absolute_import, division, print_function, unicode_literals
from detectron2 import _C
def pairwise_iou_rotated(boxes1, boxes2):
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in
(x_center, y_center, width, height, angle) format.
Arguments:
boxes1 (Tensor[N, 5])
boxes2 (Tensor[M, 5])
Returns:
iou (Tensor[N, M]): the NxM matrix containing the pairwise
IoU values for every element in boxes1 and boxes2
"""
return _C.box_iou_rotated(boxes1, boxes2)
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
from collections import namedtuple
class ShapeSpec(namedtuple("_ShapeSpec", ["channels", "height", "width", "stride"])):
"""
A simple structure that contains basic shape specification about a tensor.
It is often used as the auxiliary inputs/outputs of models,
to complement the lack of shape inference ability among pytorch modules.
Attributes:
channels:
height:
width:
stride:
"""
def __new__(cls, channels=None, height=None, width=None, stride=None):
return super().__new__(cls, channels, height, width, stride)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment