Commit cc26cd81 authored by panning

merge v0.16.0

parents f78f29f5 fbb4cc54
@@ -60,7 +60,7 @@ void roi_align_forward_kernel_impl(
   // When the grid is empty, output zeros.
   const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4

-  // we want to precalculate indices and weights shared by all chanels,
+  // we want to precalculate indices and weights shared by all channels,
   // this is the key point of optimization
   std::vector<detail::PreCalc<T>> pre_calc(
       roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
...
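The comment above describes the key CPU-side optimization: the four corner offsets and bilinear weights of every sampling point depend only on the ROI geometry, not on the channel, so they are computed once into `pre_calc` and then reused for every channel. A minimal, self-contained sketch of that idea (the `PreCalcEntry` fields and the `pool_all_channels` helper are illustrative stand-ins, not the torchvision source):

// Hypothetical illustration: the bilinear corner offsets and weights are
// channel-independent, so they are computed once per ROI and then reused by a
// cheap weighted-sum loop over all channels.
#include <cstddef>
#include <vector>

struct PreCalcEntry {
  int pos1, pos2, pos3, pos4; // flat offsets of the four bilinear corners
  float w1, w2, w3, w4;       // matching bilinear weights
};

// Weighted sum of the four precomputed corners for one sampling point.
inline float sample_precalc(const float* channel, const PreCalcEntry& pc) {
  return pc.w1 * channel[pc.pos1] + pc.w2 * channel[pc.pos2] +
         pc.w3 * channel[pc.pos3] + pc.w4 * channel[pc.pos4];
}

// The per-channel loop only reads the shared table; none of the ROI geometry
// is recomputed per channel.
void pool_all_channels(
    const float* input, // laid out as [channels][height * width]
    float* output,      // laid out as [channels][num_sampling_points]
    int channels,
    int height_times_width,
    const std::vector<PreCalcEntry>& pre_calc) {
  for (int c = 0; c < channels; ++c) {
    const float* channel =
        input + static_cast<std::size_t>(c) * height_times_width;
    for (std::size_t p = 0; p < pre_calc.size(); ++p) {
      output[static_cast<std::size_t>(c) * pre_calc.size() + p] =
          sample_precalc(channel, pre_calc[p]);
    }
  }
}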
@@ -70,7 +70,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/library.h>
-#include <ATen/cuda/Atomic.cuh>
+#include <ATen/native/cuda/KernelUtils.cuh>
 #include "cuda_helpers.h"
@@ -300,11 +300,7 @@ void deformable_im2col(
             data_col.data_ptr<scalar_t>());
       }));
   }
-
-  cudaError_t err = cudaGetLastError();
-  if (err != cudaSuccess) {
-    printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
-  }
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }

 int get_greatest_divisor_below_bound(int n, int bound) {
@@ -339,6 +335,8 @@ __global__ void deformable_col2im_kernel(
     index_t out_w,
     bool use_mask,
     scalar_t* grad_im) {
+  const index_t grad_im_numel = width * height * channels * batch_sz;
+
   CUDA_1D_KERNEL_LOOP_T(index, n, int64_t) {
     const index_t out_x = index % out_w;
     const index_t out_y = (index / out_w) % out_h;
@@ -385,7 +383,12 @@ __global__ void deformable_col2im_kernel(
             std::abs(y - yp) < 1 && std::abs(x - xp) < 1) {
           index_t grad_pos = ((b * channels + c) * height + yp) * width + xp;
           scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp));
-          gpuAtomicAdd(grad_im + grad_pos, mask_value * weight * col[index]);
+          at::native::fastAtomicAdd(
+              grad_im,
+              grad_pos,
+              grad_im_numel,
+              mask_value * weight * col[index],
+              true);
         }
       }
     }
@@ -430,6 +433,8 @@ void compute_grad_input(
   // Checks if num_kernels or columns numel larger than 2 ** 31
   use_64bits_indexing |= num_kernels > (1 << 31);

+  at::globalContext().alertNotDeterministic("compute_grad_input");
+
   if (use_64bits_indexing) {
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(
         columns.scalar_type(), "compute_grad_input", ([&] {
@@ -483,11 +488,7 @@ void compute_grad_input(
             grad_im.data_ptr<scalar_t>());
       }));
   }
-
-  cudaError_t err = cudaGetLastError();
-  if (err != cudaSuccess) {
-    printf("error in compute_grad_input: %s\n", cudaGetErrorString(err));
-  }
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }

 template <typename scalar_t, typename index_t>
@@ -736,12 +737,7 @@ void compute_grad_offset_and_mask(
             grad_mask.data_ptr<scalar_t>());
       }));
   }
-
-  cudaError_t err = cudaGetLastError();
-  if (err != cudaSuccess) {
-    printf(
-        "error in compute_grad_offset_and_mask: %s\n", cudaGetErrorString(err));
-  }
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }

 std::tuple<at::Tensor, at::Tensor, at::Tensor> backward_gradient_inputs(
...
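Two changes recur in this file and in the ROI kernels that follow: the post-launch `cudaGetLastError()`/`printf` blocks are replaced by `C10_CUDA_KERNEL_LAUNCH_CHECK()`, which turns launch failures into C++ exceptions, and every scatter-add goes through `at::native::fastAtomicAdd(base, index, numel, value, fast_path)` instead of `gpuAtomicAdd(base + index, value)`, so half-precision adds can take the packed fast path while the `numel` bound (the `grad_im_numel` / `memory_span` values above) keeps them inside the tensor. A minimal CUDA sketch of the calling pattern; the `scatter_add_kernel` example itself is hypothetical, not part of the diff:

#include <ATen/native/cuda/KernelUtils.cuh>
#include <c10/cuda/CUDAException.h>

// Each thread scatters one value into `grad` at a data-dependent index.
template <typename scalar_t>
__global__ void scatter_add_kernel(
    scalar_t* grad,
    const int64_t* indices,
    const scalar_t* values,
    int64_t n,
    int64_t grad_numel) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // fastAtomicAdd(base, index, numel, value, fast_path) replaces
    // gpuAtomicAdd(base + index, value).
    at::native::fastAtomicAdd(grad, indices[i], grad_numel, values[i], true);
  }
}

void scatter_add(
    float* grad,
    const int64_t* indices,
    const float* values,
    int64_t n,
    int64_t grad_numel) {
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  scatter_add_kernel<float>
      <<<blocks, threads>>>(grad, indices, values, n, grad_numel);
  // Replaces the old printf-based check: raises a c10::Error on launch failure.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}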
@@ -2,7 +2,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/library.h>
-#include <ATen/cuda/Atomic.cuh>
+#include <ATen/native/cuda/KernelUtils.cuh>
 #include "cuda_helpers.h"
@@ -212,7 +212,8 @@ __global__ void ps_roi_align_backward_kernel_impl(
     int sampling_ratio,
     int channels_out,
     T* grad_input,
-    const T* rois) {
+    const T* rois,
+    const int memory_span) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, *, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;
@@ -235,8 +236,6 @@ __global__ void ps_roi_align_backward_kernel_impl(
     T bin_size_w = roi_width / static_cast<T>(pooled_width);

     int c_in = channel_mapping[index];
-    T* grad_input_offset =
-        grad_input + (roi_batch_ind * channels + c_in) * height * width;

     // Do not using floor/ceil; this implementation detail is critical
     T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
@@ -252,6 +251,8 @@ __global__ void ps_roi_align_backward_kernel_impl(
         (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
     const T count = roi_bin_grid_h * roi_bin_grid_w;

+    const int offset = (roi_batch_ind * channels + c_in) * height * width;
+
     for (int iy = 0; iy < roi_bin_grid_h; iy++) {
       const T y = hstart +
           static_cast<T>(iy + .5f) * bin_size_h /
@@ -285,10 +286,30 @@ __global__ void ps_roi_align_backward_kernel_impl(
         T g4 = grad_output_this_bin * w4 / count;

         if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
-          gpuAtomicAdd(grad_input_offset + y_low * width + x_low, g1);
-          gpuAtomicAdd(grad_input_offset + y_low * width + x_high, g2);
-          gpuAtomicAdd(grad_input_offset + y_high * width + x_low, g3);
-          gpuAtomicAdd(grad_input_offset + y_high * width + x_high, g4);
+          at::native::fastAtomicAdd(
+              grad_input,
+              offset + y_low * width + x_low,
+              memory_span,
+              static_cast<T>(g1),
+              true);
+          at::native::fastAtomicAdd(
+              grad_input,
+              offset + y_low * width + x_high,
+              memory_span,
+              static_cast<T>(g2),
+              true);
+          at::native::fastAtomicAdd(
+              grad_input,
+              offset + y_high * width + x_low,
+              memory_span,
+              static_cast<T>(g3),
+              true);
+          at::native::fastAtomicAdd(
+              grad_input,
+              offset + y_high * width + x_high,
+              memory_span,
+              static_cast<T>(g4),
+              true);
         } // if
       } // ix
     } // iy
@@ -412,6 +433,8 @@ at::Tensor ps_roi_align_backward_kernel(
   int channels_out = channels / (pooled_height * pooled_width);

+  at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel");
+
   auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       grad.scalar_type(), "ps_roi_align_backward_kernel", [&] {
@@ -428,7 +451,8 @@ at::Tensor ps_roi_align_backward_kernel(
             sampling_ratio,
             channels_out,
             grad_input.data_ptr<scalar_t>(),
-            rois_.data_ptr<scalar_t>());
+            rois_.data_ptr<scalar_t>(),
+            grad_input.numel());
       });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
...
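The new `alertNotDeterministic(...)` calls register these backward kernels with PyTorch's deterministic-algorithms machinery: the atomic accumulation order (and the half-precision fast path) is not reproducible, so the op has to announce itself when determinism is requested. A short C++ sketch of the user-visible effect, assuming the two-argument `setDeterministicAlgorithms(enabled, warn_only)` overload available in recent ATen:

#include <ATen/Context.h>

void require_deterministic_backward() {
  // With warn_only = false, any op that calls
  // at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel")
  // (as the kernel above now does) throws a c10::Error instead of silently
  // producing run-to-run differences.
  at::globalContext().setDeterministicAlgorithms(true, /*warn_only=*/false);

  // With warn_only = true, the same call sites only emit a one-time warning.
  at::globalContext().setDeterministicAlgorithms(true, /*warn_only=*/true);
}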
@@ -2,7 +2,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/library.h>
-#include <ATen/cuda/Atomic.cuh>
+#include <ATen/native/cuda/KernelUtils.cuh>
 #include "cuda_helpers.h"
@@ -91,7 +91,8 @@ __global__ void ps_roi_pool_backward_kernel_impl(
     int pooled_width,
     int channels_out,
     T* grad_input,
-    const T* rois) {
+    const T* rois,
+    const int memory_span) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, *, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;
@@ -124,14 +125,15 @@ __global__ void ps_roi_pool_backward_kernel_impl(
     bool is_empty = (hend <= hstart) || (wend <= wstart);

     int c_in = channel_mapping[index];
-    T* grad_input_offset =
-        grad_input + (roi_batch_ind * channels + c_in) * height * width;

     T bin_area = (hend - hstart) * (wend - wstart);
     T diff_val = is_empty ? static_cast<T>(0) : grad_output[index] / bin_area;

+    const int offset = (roi_batch_ind * channels + c_in) * height * width;
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
         int grad_input_index = h * width + w;
-        gpuAtomicAdd(grad_input_offset + grad_input_index, diff_val);
+        at::native::fastAtomicAdd(
+            grad_input, offset + grad_input_index, memory_span, diff_val, true);
       }
     }
   }
@@ -251,6 +253,8 @@ at::Tensor ps_roi_pool_backward_kernel(
   int channels_out = channels / (pooled_height * pooled_width);

+  at::globalContext().alertNotDeterministic("ps_roi_pool_backward_kernel");
+
   auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       grad.scalar_type(), "ps_roi_pool_backward_kernel", [&] {
@@ -267,7 +271,8 @@ at::Tensor ps_roi_pool_backward_kernel(
             pooled_width,
             channels_out,
             grad_input.data_ptr<scalar_t>(),
-            rois_.data_ptr<scalar_t>());
+            rois_.data_ptr<scalar_t>(),
+            grad_input.numel());
       });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
...
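For reference, the scatter in `ps_roi_pool_backward_kernel_impl` is just the backward of average pooling over a bin: the incoming gradient is divided by the bin area and added to every cell the bin covered,

$$
\frac{\partial L}{\partial x_{h,w}} \mathrel{+}= \frac{1}{|B|}\,\frac{\partial L}{\partial y_{\text{bin}}}
\quad \text{for all } (h, w) \in B, \qquad
|B| = (h_{\text{end}} - h_{\text{start}})(w_{\text{end}} - w_{\text{start}}),
$$

which is exactly the `diff_val` that each `fastAtomicAdd` above deposits.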
@@ -2,7 +2,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/library.h>
-#include <ATen/cuda/Atomic.cuh>
+#include <ATen/native/cuda/KernelUtils.cuh>
 #include "cuda_helpers.h"
@@ -218,7 +218,8 @@ __global__ void roi_align_backward_kernel_impl(
     int n_stride,
     int c_stride,
     int h_stride,
-    int w_stride) {
+    int w_stride,
+    const int memory_span) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;
@@ -247,12 +248,9 @@ __global__ void roi_align_backward_kernel_impl(
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

-    T* offset_grad_input =
-        grad_input + ((roi_batch_ind * channels + c) * height * width);
-
     // We need to index the gradient using the tensor strides to access the
     // correct values.
-    int output_offset = n * n_stride + c * c_stride;
+    const int output_offset = n * n_stride + c * c_stride;
     const T* offset_grad_output = grad_output + output_offset;
     const T grad_output_this_bin =
         offset_grad_output[ph * h_stride + pw * w_stride];
@@ -267,6 +265,8 @@ __global__ void roi_align_backward_kernel_impl(
     // We do average (integral) pooling inside a bin
     const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4

+    const int input_offset = (roi_batch_ind * channels + c) * height * width;
+
     for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
     {
       const T y = roi_start_h + ph * bin_size_h +
@@ -301,14 +301,30 @@ __global__ void roi_align_backward_kernel_impl(
         T g4 = grad_output_this_bin * w4 / count;

         if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
-          gpuAtomicAdd(
-              offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
-          gpuAtomicAdd(
-              offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
-          gpuAtomicAdd(
-              offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
-          gpuAtomicAdd(
-              offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
+          at::native::fastAtomicAdd(
+              grad_input,
+              input_offset + y_low * width + x_low,
+              memory_span,
+              static_cast<T>(g1),
+              true);
+          at::native::fastAtomicAdd(
+              grad_input,
+              input_offset + y_low * width + x_high,
+              memory_span,
+              static_cast<T>(g2),
+              true);
+          at::native::fastAtomicAdd(
+              grad_input,
+              input_offset + y_high * width + x_low,
+              memory_span,
+              static_cast<T>(g3),
+              true);
+          at::native::fastAtomicAdd(
+              grad_input,
+              input_offset + y_high * width + x_high,
+              memory_span,
+              static_cast<T>(g4),
+              true);
         } // if
       } // ix
     } // iy
@@ -421,6 +437,8 @@ at::Tensor roi_align_backward_kernel(
   int h_stride = grad.stride(2);
   int w_stride = grad.stride(3);

+  at::globalContext().alertNotDeterministic("roi_align_backward_kernel");
+
   auto rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       grad.scalar_type(), "roi_align_backward_kernel", [&] {
@@ -440,7 +458,8 @@ at::Tensor roi_align_backward_kernel(
             n_stride,
             c_stride,
             h_stride,
-            w_stride);
+            w_stride,
+            grad_input.numel());
       });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
...
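The four gradients scattered above are the standard bilinear coefficients of the sampling point $(y, x)$ relative to its integer corners; with $l_y = y - y_{\text{low}}$, $l_x = x - x_{\text{low}}$, $h_y = 1 - l_y$, $h_x = 1 - l_x$ (the same $w_1 \dots w_4$ produced by `bilinear_interpolate_gradient`):

$$
w_1 = h_y h_x,\quad w_2 = h_y l_x,\quad w_3 = l_y h_x,\quad w_4 = l_y l_x,
\qquad
g_i = \frac{\partial L}{\partial \text{out}} \cdot \frac{w_i}{\text{count}},
$$

so each of the `roi_bin_grid_h * roi_bin_grid_w` sampling points performs four atomic adds into `grad_input`.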
@@ -3,7 +3,7 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <float.h>
 #include <torch/library.h>
-#include <ATen/cuda/Atomic.cuh>
+#include <ATen/native/cuda/KernelUtils.cuh>
 #include "cuda_helpers.h"
@@ -94,7 +94,8 @@ __global__ void roi_pool_backward_kernel_impl(
     int n_stride,
     int c_stride,
     int h_stride,
-    int w_stride) {
+    int w_stride,
+    const int memory_span) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;
@@ -104,19 +105,21 @@ __global__ void roi_pool_backward_kernel_impl(
     const T* offset_rois = rois + n * 5;
     int roi_batch_ind = offset_rois[0];

-    T* grad_input_offset =
-        grad_input + ((roi_batch_ind * channels + c) * height * width);
-
-    int output_offset = n * n_stride + c * c_stride;
+    const int output_offset = n * n_stride + c * c_stride;
     const int* argmax_data_offset =
         argmax_data + (n * channels + c) * pooled_height * pooled_width;
-    int argmax = argmax_data_offset[ph * pooled_width + pw];
+    const int argmax = argmax_data_offset[ph * pooled_width + pw];
+    const int offset = (roi_batch_ind * channels + c) * height * width;

     if (argmax != -1) {
-      gpuAtomicAdd(
-          grad_input_offset + argmax,
+      at::native::fastAtomicAdd(
+          grad_input,
+          offset + argmax,
+          memory_span,
           static_cast<T>(
-              grad_output[output_offset + ph * h_stride + pw * w_stride]));
+              grad_output[output_offset + ph * h_stride + pw * w_stride]),
+          true);
     }
   }
 }
@@ -232,6 +235,8 @@ at::Tensor roi_pool_backward_kernel(
   int h_stride = grad.stride(2);
   int w_stride = grad.stride(3);

+  at::globalContext().alertNotDeterministic("roi_pool_backward_kernel");
+
   auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       grad.scalar_type(), "roi_pool_backward_kernel", [&] {
@@ -251,7 +256,8 @@ at::Tensor roi_pool_backward_kernel(
             n_stride,
             c_stride,
             h_stride,
-            w_stride);
+            w_stride,
+            grad_input.numel());
       });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
...
constexpr int threadsPerBlock = 512;
template <typename T>
constexpr inline T ceil_div(T n, T m) {
return (n + m - 1) / m;
}
#include <ATen/native/mps/OperationUtils.h>
namespace vision {
namespace ops {
namespace mps {
static const char* METAL_VISION = R"VISION_METAL(
#include <metal_atomic>
#include <metal_stdlib>
using namespace metal;
/*----------Macros----------*/
#define MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, index_t) \
for (index_t i = (tgid.x * tptg.x) + tid2.x; i < (n); \
i += (tptg.x * n_tgs))
#define MPS_1D_KERNEL_LOOP(i, n, n_tgs) MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, uint)
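// Note: the loops above are Metal ports of the grid-stride-loop idiom: each
// thread starts at its flattened global id ((tgid.x * tptg.x) + tid2.x) and
// advances by the total number of launched threads (tptg.x * n_tgs), so a
// fixed-size dispatch covers any `n`.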
/*----------Helpers--------*/
template <typename T>
inline T ceil_div(T n, T m) {
return (n + m - 1) / m;
}
template <typename T>
inline void atomic_add_float( device T* data_ptr, const T val)
{
#if __METAL_VERSION__ >= 300
// atomic_float is supported in Metal 3 (macOS Ventura) onward.
device atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed);
#else
// Custom atomic addition implementation
// https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472
// https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639
// https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide)
// Create an atomic uint pointer for atomic transaction.
device atomic_uint* atom_var = (device atomic_uint*)data_ptr;
// Create necessary storage.
uint fetched_uint, assigning_uint;
T fetched_float, assigning_float;
// Replace the value in atom_var with 0 and return the previous value in atom_var.
fetched_uint = atomic_exchange_explicit( atom_var, 0 /*desired*/, memory_order_relaxed);
// Read out the previous value as float.
fetched_float = *( (thread T*) &fetched_uint );
// Do addition and represent the addition result in uint for atomic transaction.
assigning_float = fetched_float + val;
assigning_uint = *((thread uint*) &assigning_float);
// atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr).
while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) {
// If atom_var was not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads.
// Try to assign 0 and get the previously assigned addition result.
uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed);
T fetched_float_again = *( (thread T*) &fetched_uint_again );
// Re-add again
fetched_float = *((thread T*) &(fetched_uint));
// Previously assigned addition result + addition result from other threads.
assigning_float = fetched_float_again + fetched_float;
assigning_uint = *( (thread uint*) &assigning_float);
}
#endif
}
template <typename T, typename integer_t>
inline T bilinear_interpolate(
constant T* input,
integer_t height,
integer_t width,
T y,
T x,
uint index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
return 0;
}
if (y <= 0)
y = 0;
if (x <= 0)
x = 0;
integer_t y_low = (integer_t)y;
integer_t x_low = (integer_t)x;
integer_t y_high;
integer_t x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// do bilinear interpolation
T v1 = input[y_low * width + x_low];
T v2 = input[y_low * width + x_high];
T v3 = input[y_high * width + x_low];
T v4 = input[y_high * width + x_high];
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename T, typename integer_t>
inline void bilinear_interpolate_gradient(
integer_t height,
integer_t width,
T y,
T x,
thread T& w1,
thread T& w2,
thread T& w3,
thread T& w4,
thread integer_t& x_low,
thread integer_t& x_high,
thread integer_t& y_low,
thread integer_t& y_high,
uint index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
w1 = w2 = w3 = w4 = 0.;
x_low = x_high = y_low = y_high = -1;
return;
}
if (y <= 0)
y = 0;
if (x <= 0)
x = 0;
y_low = (integer_t)y;
x_low = (integer_t)x;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// reference in forward
// T v1 = input[y_low * width + x_low];
// T v2 = input[y_low * width + x_high];
// T v3 = input[y_high * width + x_low];
// T v4 = input[y_high * width + x_high];
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
}
template <typename T, typename scalar_t>
inline bool IoU(
constant T & a,
threadgroup T & b,
const float threshold) {
auto xx1 = max(a.x, b.x);
auto yy1 = max(a.y, b.y);
auto xx2 = min(a.z, b.z);
auto yy2 = min(a.w, b.w);
auto w = max(static_cast<scalar_t>(0), xx2 - xx1);
auto h = max(static_cast<scalar_t>(0), yy2 - yy1);
// Upcast to float before multiplications to circumvent precision issues in half.
auto inter = static_cast<float>(w) * static_cast<float>(h);
auto area_b = static_cast<float>(b.z - b.x) * static_cast<float>(b.w - b.y);
auto area_a = static_cast<float>(a.z - a.x) * static_cast<float>(a.w - a.y);
return (inter / (area_a + area_b - inter)) > threshold;
}
/*----------Kernels----------*/
// This should be in sync with the one in nms_kernel.mm.
// Since metal does not support dynamic array,
// we need to make it static instead of deriving it from [[threads_per_threadgroup]].
constant int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8;
template<typename T, typename scalar_t>
kernel void nms(constant T * dev_boxes [[buffer(0)]],
device uint64_t * mask [[buffer(1)]],
constant int64_t & n_boxes [[buffer(2)]],
constant float & iou_threshold [[buffer(3)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tid2 [[thread_position_in_threadgroup]]) {
const uint row_start = tgid.y;
const uint col_start = tgid.x;
const uint tid = tid2.x;
const uint row_size =
min(n_boxes - row_start * nmsThreadsPerBlock, nmsThreadsPerBlock);
const uint col_size =
min(n_boxes - col_start * nmsThreadsPerBlock, nmsThreadsPerBlock);
threadgroup T block_boxes[nmsThreadsPerBlock];
block_boxes[tid] = dev_boxes[nmsThreadsPerBlock * col_start + tid];
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid < row_size) {
const uint cur_box_idx = nmsThreadsPerBlock * row_start + tid;
uint64_t t = 0;
uint start = 0;
if (row_start == col_start) {
start = tid + 1;
}
for (uint i = start; i < col_size; i++){
if (IoU<T, scalar_t>(dev_boxes[cur_box_idx], block_boxes[i], iou_threshold)){
t |= static_cast<uint64_t>(1) << i; // discard 1 keep 0
}
}
const uint col_blocks = ceil_div(n_boxes, nmsThreadsPerBlock);
mask[cur_box_idx * col_blocks + col_start] = t;
}
}
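// The kernel above tiles the boxes into nmsThreadsPerBlock x nmsThreadsPerBlock
// blocks: each threadgroup loads one column tile into threadgroup memory, and
// every row box packs its "suppress this column box" decisions into a single
// 64-bit word of `mask`, which the caller can later scan to select the kept boxes.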
#define REGISTER_NMS_OP(DTYPE) \
template \
[[host_name("nms_" #DTYPE)]] \
kernel void nms<DTYPE ## 4, DTYPE>( \
constant DTYPE ## 4 * dev_boxes [[buffer(0)]], \
device uint64_t * mask [[buffer(1)]], \
constant int64_t & n_boxes [[buffer(2)]], \
constant float & iou_threshold [[buffer(3)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void roi_align(
constant T * input [[buffer(0)]],
constant T * rois [[buffer(1)]],
device T * output [[buffer(2)]],
constant int64_t & output_size [[buffer(3)]],
constant int64_t & channels [[buffer(4)]],
constant int64_t & height [[buffer(5)]],
constant int64_t & width [[buffer(6)]],
constant int64_t & pooled_height [[buffer(7)]],
constant int64_t & pooled_width [[buffer(8)]],
constant int64_t & sampling_ratio [[buffer(9)]],
constant bool & aligned [[buffer(10)]],
constant float & spatial_scale [[buffer(11)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c = (index / pooled_width / pooled_height) % channels;
integer_t n = index / pooled_width / pooled_height / channels;
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
  // Do not use rounding; this implementation detail is critical
T offset = aligned ? (T)0.5 : (T)0.0;
T roi_start_w = offset_rois[1] * spatial_scale - offset;
T roi_start_h = offset_rois[2] * spatial_scale - offset;
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!aligned) {
// Force malformed ROIs to be 1x1
roi_width = max(roi_width, (T)1.);
roi_height = max(roi_height, (T)1.);
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
constant T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;
// We use roi_bin_grid to sample the grid and mimic integral
integer_t roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
integer_t roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
// We do average (integral) pooling inside a bin
// When the grid is empty, output zeros.
const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast<integer_t>(1)); // e.g. = 4
T output_val = 0.;
for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
{
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T val = bilinear_interpolate(offset_input, height, width, y, x, index);
output_val += val;
}
}
output_val /= count;
output[index] = output_val;
}
}
#define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("roi_align_" #DTYPE)]] \
kernel void roi_align<DTYPE, INT_DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * output [[buffer(2)]], \
constant int64_t & output_size [[buffer(3)]], \
constant int64_t & channels [[buffer(4)]], \
constant int64_t & height [[buffer(5)]], \
constant int64_t & width [[buffer(6)]], \
constant int64_t & pooled_height [[buffer(7)]], \
constant int64_t & pooled_width [[buffer(8)]], \
constant int64_t & sampling_ratio [[buffer(9)]], \
constant bool & aligned [[buffer(10)]], \
constant float & spatial_scale [[buffer(11)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void roi_align_backward(
constant T * grad_output [[buffer(0)]],
constant T * rois [[buffer(1)]],
device T * grad_input [[buffer(2)]],
constant int64_t & output_size [[buffer(3)]],
constant int64_t & channels [[buffer(4)]],
constant int64_t & height [[buffer(5)]],
constant int64_t & width [[buffer(6)]],
constant int64_t & pooled_height [[buffer(7)]],
constant int64_t & pooled_width [[buffer(8)]],
constant int64_t & sampling_ratio [[buffer(9)]],
constant bool & aligned [[buffer(10)]],
constant float & spatial_scale [[buffer(11)]],
constant int64_t & n_stride [[buffer(12)]],
constant int64_t & c_stride [[buffer(13)]],
constant int64_t & h_stride [[buffer(14)]],
constant int64_t & w_stride [[buffer(15)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c = (index / pooled_width / pooled_height) % channels;
integer_t n = index / pooled_width / pooled_height / channels;
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
  // Do not use rounding; this implementation detail is critical
T offset = aligned ? (T)0.5 : (T)0.0;
T roi_start_w = offset_rois[1] * spatial_scale - offset;
T roi_start_h = offset_rois[2] * spatial_scale - offset;
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!aligned) {
// Force malformed ROIs to be 1x1
roi_width = max(roi_width, (T)1.);
roi_height = max(roi_height, (T)1.);
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
// We need to index the gradient using the tensor strides to access the
// correct values.
const integer_t output_offset = n * n_stride + c * c_stride;
constant T* offset_grad_output = grad_output + output_offset;
const T grad_output_this_bin =
offset_grad_output[ph * h_stride + pw * w_stride];
// We use roi_bin_grid to sample the grid and mimic integral
integer_t roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
integer_t roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
const integer_t input_offset = (roi_batch_ind * channels + c) * height * width;
for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
{
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T w1, w2, w3, w4;
integer_t x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient(
height,
width,
y,
x,
w1,
w2,
w3,
w4,
x_low,
x_high,
y_low,
y_high,
index);
T g1 = grad_output_this_bin * w1 / count;
T g2 = grad_output_this_bin * w2 / count;
T g3 = grad_output_this_bin * w3 / count;
T g4 = grad_output_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomic_add_float(grad_input + input_offset + y_low * width + x_low, static_cast<T>(g1));
atomic_add_float(grad_input + input_offset + y_low * width + x_high, static_cast<T>(g2));
atomic_add_float(grad_input + input_offset + y_high * width + x_low, static_cast<T>(g3));
atomic_add_float(grad_input + input_offset + y_high * width + x_high, static_cast<T>(g4));
} // if
} // ix
} // iy
} // MPS_1D_KERNEL_LOOP
}
#define REGISTER_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("roi_align_backward_" #DTYPE)]] \
kernel void roi_align_backward<DTYPE, INT_DTYPE>( \
constant DTYPE * grad_output [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * grad_input [[buffer(2)]], \
constant int64_t & output_size [[buffer(3)]], \
constant int64_t & channels [[buffer(4)]], \
constant int64_t & height [[buffer(5)]], \
constant int64_t & width [[buffer(6)]], \
constant int64_t & pooled_height [[buffer(7)]], \
constant int64_t & pooled_width [[buffer(8)]], \
constant int64_t & sampling_ratio [[buffer(9)]], \
constant bool & aligned [[buffer(10)]], \
constant float & spatial_scale [[buffer(11)]], \
constant int64_t & n_stride [[buffer(12)]], \
constant int64_t & c_stride [[buffer(13)]], \
constant int64_t & h_stride [[buffer(14)]], \
constant int64_t & w_stride [[buffer(15)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void roi_pool(
constant T * input [[buffer(0)]],
constant T * rois [[buffer(1)]],
device T * output [[buffer(2)]],
device int64_t * argmax [[buffer(3)]],
constant int64_t & output_size [[buffer(4)]],
constant int64_t & channels [[buffer(5)]],
constant int64_t & height [[buffer(6)]],
constant int64_t & width [[buffer(7)]],
constant int64_t & pooled_height [[buffer(8)]],
constant int64_t & pooled_width [[buffer(9)]],
constant float & spatial_scale [[buffer(10)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c = (index / pooled_width / pooled_height) % channels;
integer_t n = index / pooled_width / pooled_height / channels;
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
integer_t roi_start_w = round(offset_rois[1] * spatial_scale);
integer_t roi_start_h = round(offset_rois[2] * spatial_scale);
integer_t roi_end_w = round(offset_rois[3] * spatial_scale);
integer_t roi_end_h = round(offset_rois[4] * spatial_scale);
// Force malformed ROIs to be 1x1
integer_t roi_width = max(roi_end_w - roi_start_w + 1, static_cast<integer_t>(1));
integer_t roi_height = max(roi_end_h - roi_start_h + 1, static_cast<integer_t>(1));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
integer_t hstart = static_cast<integer_t>(floor(static_cast<T>(ph) * bin_size_h));
integer_t wstart = static_cast<integer_t>(floor(static_cast<T>(pw) * bin_size_w));
integer_t hend = static_cast<integer_t>(ceil(static_cast<T>(ph + 1) * bin_size_h));
integer_t wend = static_cast<integer_t>(ceil(static_cast<T>(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height));
hend = min(max(hend + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height));
wstart = min(max(wstart + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width));
wend = min(max(wend + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width));
bool is_empty = (hend <= hstart) || (wend <= wstart);
// Define an empty pooling region to be zero
T maxval = is_empty ? 0 : -FLT_MAX;
// If nothing is pooled, argmax = -1 causes nothing to be backprop'd
integer_t maxidx = -1;
constant T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;
for (integer_t h = hstart; h < hend; ++h) {
for (integer_t w = wstart; w < wend; ++w) {
integer_t input_index = h * width + w;
if (offset_input[input_index] > maxval) {
maxval = offset_input[input_index];
maxidx = input_index;
}
}
}
output[index] = maxval;
argmax[index] = maxidx;
}
}
#define REGISTER_ROI_POOL_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("roi_pool_" #DTYPE)]] \
kernel void roi_pool<DTYPE, INT_DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * output [[buffer(2)]], \
device int64_t * argmax_data [[buffer(3)]], \
constant int64_t & output_size [[buffer(4)]], \
constant int64_t & channels [[buffer(5)]], \
constant int64_t & height [[buffer(6)]], \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant float & spatial_scale [[buffer(10)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void roi_pool_backward(
constant T * grad_output [[buffer(0)]],
constant T * rois [[buffer(1)]],
constant int64_t * argmax_data [[buffer(2)]],
device T * grad_input [[buffer(3)]],
constant int64_t & output_size [[buffer(4)]],
constant int64_t & channels [[buffer(5)]],
constant int64_t & height [[buffer(6)]],
constant int64_t & width [[buffer(7)]],
constant int64_t & pooled_height [[buffer(8)]],
constant int64_t & pooled_width [[buffer(9)]],
constant float & spatial_scale [[buffer(10)]],
constant int64_t & n_stride [[buffer(11)]],
constant int64_t & c_stride [[buffer(12)]],
constant int64_t & h_stride [[buffer(13)]],
constant int64_t & w_stride [[buffer(14)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c = (index / pooled_width / pooled_height) % channels;
integer_t n = index / pooled_width / pooled_height / channels;
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
const integer_t output_offset = n * n_stride + c * c_stride;
constant integer_t * argmax_data_offset =
argmax_data + (n * channels + c) * pooled_height * pooled_width;
const integer_t argmax = argmax_data_offset[ph * pooled_width + pw];
const integer_t offset = (roi_batch_ind * channels + c) * height * width;
if (argmax != -1) {
atomic_add_float(grad_input + offset + argmax, static_cast<T>(grad_output[output_offset + ph * h_stride + pw * w_stride]));
}
} // MPS_1D_KERNEL_LOOP
}
#define REGISTER_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("roi_pool_backward_" #DTYPE)]] \
kernel void roi_pool_backward<DTYPE, INT_DTYPE>( \
constant DTYPE * grad_output [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
constant int64_t * argmax_data [[buffer(2)]], \
device DTYPE * grad_input [[buffer(3)]], \
constant int64_t & output_size [[buffer(4)]], \
constant int64_t & channels [[buffer(5)]], \
constant int64_t & height [[buffer(6)]], \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant float & spatial_scale [[buffer(10)]], \
constant int64_t & n_stride [[buffer(11)]], \
constant int64_t & c_stride [[buffer(12)]], \
constant int64_t & h_stride [[buffer(13)]], \
constant int64_t & w_stride [[buffer(14)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void ps_roi_align(
constant T * input [[buffer(0)]],
constant T * rois [[buffer(1)]],
device T * output [[buffer(2)]],
device int64_t * channel_mapping [[buffer(3)]],
constant int64_t & output_size [[buffer(4)]],
constant int64_t & channels [[buffer(5)]],
constant int64_t & height [[buffer(6)]],
constant int64_t & width [[buffer(7)]],
constant int64_t & pooled_height [[buffer(8)]],
constant int64_t & pooled_width [[buffer(9)]],
constant int64_t & sampling_ratio [[buffer(10)]],
constant int64_t & channels_out [[buffer(11)]],
constant float & spatial_scale [[buffer(12)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c_out, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c_out = (index / pooled_width / pooled_height) % channels_out;
integer_t n = index / pooled_width / pooled_height / channels_out;
// (n, c_in, ph, pw) is the associated element in the input
integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw;
// [start, end) interval for spatial sampling
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
  // Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
  // Do not use floor/ceil; this implementation detail is critical
T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
// We use roi_bin_grid to sample the grid and mimic integral
integer_t roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height);
integer_t roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w;
constant T* offset_input =
input + (roi_batch_ind * channels + c_in) * height * width;
T out_sum = 0;
for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = hstart +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = wstart +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T val = bilinear_interpolate(offset_input, height, width, y, x, index);
out_sum += val;
}
}
out_sum /= count;
output[index] = out_sum;
channel_mapping[index] = c_in;
}
}
#define REGISTER_PS_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("ps_roi_align_" #DTYPE)]] \
kernel void ps_roi_align<DTYPE, INT_DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * output [[buffer(2)]], \
device int64_t * channel_mapping [[buffer(3)]], \
constant int64_t & output_size [[buffer(4)]], \
constant int64_t & channels [[buffer(5)]], \
constant int64_t & height [[buffer(6)]], \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant int64_t & sampling_ratio [[buffer(10)]], \
constant int64_t & channels_out [[buffer(11)]], \
constant float & spatial_scale [[buffer(12)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void ps_roi_align_backward(
constant T * grad_output [[buffer(0)]],
constant T * rois [[buffer(1)]],
constant int64_t * channel_mapping [[buffer(2)]],
device T * grad_input [[buffer(3)]],
constant int64_t & output_size [[buffer(4)]],
constant int64_t & channels [[buffer(5)]],
constant int64_t & height [[buffer(6)]],
constant int64_t & width [[buffer(7)]],
constant int64_t & pooled_height [[buffer(8)]],
constant int64_t & pooled_width [[buffer(9)]],
constant int64_t & sampling_ratio [[buffer(10)]],
constant int64_t & channels_out [[buffer(11)]],
constant float & spatial_scale [[buffer(12)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, *, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t n = index / pooled_width / pooled_height / channels_out;
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
  // Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);
// Force too small ROIs to be 1x1
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
integer_t c_in = channel_mapping[index];
  // Do not use floor/ceil; this implementation detail is critical
T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
const T grad_output_this_bin = grad_output[index];
// We use roi_bin_grid to sample the grid and mimic integral
integer_t roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
integer_t roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w;
const integer_t offset = (roi_batch_ind * channels + c_in) * height * width;
for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = hstart +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = wstart +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T w1, w2, w3, w4;
integer_t x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient(
height,
width,
y,
x,
w1,
w2,
w3,
w4,
x_low,
x_high,
y_low,
y_high,
index);
T g1 = grad_output_this_bin * w1 / count;
T g2 = grad_output_this_bin * w2 / count;
T g3 = grad_output_this_bin * w3 / count;
T g4 = grad_output_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomic_add_float(grad_input + offset + y_low * width + x_low, static_cast<T>(g1));
atomic_add_float(grad_input + offset + y_low * width + x_high, static_cast<T>(g2));
atomic_add_float(grad_input + offset + y_high * width + x_low, static_cast<T>(g3));
atomic_add_float(grad_input + offset + y_high * width + x_high, static_cast<T>(g4));
} // if
} // ix
} // iy
}
}
#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("ps_roi_align_backward_" #DTYPE)]] \
kernel void ps_roi_align_backward<DTYPE, INT_DTYPE>( \
constant DTYPE * grad_output [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
constant int64_t * channel_mapping [[buffer(2)]], \
device DTYPE * grad_input [[buffer(3)]], \
constant int64_t & output_size [[buffer(4)]], \
constant int64_t & channels [[buffer(5)]], \
constant int64_t & height [[buffer(6)]], \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant int64_t & sampling_ratio [[buffer(10)]], \
constant int64_t & channels_out [[buffer(11)]], \
constant float & spatial_scale [[buffer(12)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void ps_roi_pool(
constant T * input [[buffer(0)]],
constant T * rois [[buffer(1)]],
device T * output [[buffer(2)]],
device int64_t * channel_mapping [[buffer(3)]],
constant int64_t & output_size [[buffer(4)]],
constant int64_t & channels [[buffer(5)]],
constant int64_t & height [[buffer(6)]],
constant int64_t & width [[buffer(7)]],
constant int64_t & pooled_height [[buffer(8)]],
constant int64_t & pooled_width [[buffer(9)]],
constant int64_t & channels_out [[buffer(10)]],
constant float & spatial_scale [[buffer(11)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c_out, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c_out = (index / (pooled_width * pooled_height)) % channels_out;
integer_t n = index / pooled_width / pooled_height / channels_out;
// (n, c_in, ph, pw) is the associated element in the input
integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw;
// [start, end) interval for spatial sampling
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
integer_t roi_start_w = round(offset_rois[1] * spatial_scale);
integer_t roi_start_h = round(offset_rois[2] * spatial_scale);
integer_t roi_end_w = round(offset_rois[3] * spatial_scale);
integer_t roi_end_h = round(offset_rois[4] * spatial_scale);
// Force too small ROIs to be 1x1
integer_t roi_width = max(roi_end_w - roi_start_w, static_cast<integer_t>(1));
integer_t roi_height = max(roi_end_h - roi_start_h, static_cast<integer_t>(1));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
integer_t hstart = static_cast<integer_t>(floor(static_cast<T>(ph) * bin_size_h));
integer_t wstart = static_cast<integer_t>(floor(static_cast<T>(pw) * bin_size_w));
integer_t hend = static_cast<integer_t>(ceil(static_cast<T>(ph + 1) * bin_size_h));
integer_t wend = static_cast<integer_t>(ceil(static_cast<T>(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height - 1));
hend = min(max(hend + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height - 1));
wstart = min(max(wstart + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width - 1));
wend = min(max(wend + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width - 1));
bool is_empty = (hend <= hstart) || (wend <= wstart);
constant T* offset_input =
input + (roi_batch_ind * channels + c_in) * height * width;
T out_sum = 0;
for (integer_t h = hstart; h < hend; ++h) {
for (integer_t w = wstart; w < wend; ++w) {
integer_t input_index = h * width + w;
out_sum += offset_input[input_index];
}
}
T bin_area = (hend - hstart) * (wend - wstart);
output[index] = is_empty ? static_cast<T>(0) : out_sum / bin_area;
channel_mapping[index] = c_in;
}
}
#define REGISTER_PS_ROI_POOL_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("ps_roi_pool_" #DTYPE)]] \
kernel void ps_roi_pool<DTYPE, INT_DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * output [[buffer(2)]], \
device int64_t * channel_mapping [[buffer(3)]], \
constant int64_t & output_size [[buffer(4)]], \
constant int64_t & channels [[buffer(5)]], \
constant int64_t & height [[buffer(6)]], \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant int64_t & channels_out [[buffer(10)]], \
constant float & spatial_scale [[buffer(11)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void ps_roi_pool_backward(
constant T * grad_output [[buffer(0)]],
constant T * rois [[buffer(1)]],
constant int64_t * channel_mapping [[buffer(2)]],
device T * grad_input [[buffer(3)]],
constant int64_t & output_size [[buffer(4)]],
constant int64_t & channels [[buffer(5)]],
constant int64_t & height [[buffer(6)]],
constant int64_t & width [[buffer(7)]],
constant int64_t & pooled_height [[buffer(8)]],
constant int64_t & pooled_width [[buffer(9)]],
constant int64_t & channels_out [[buffer(10)]],
constant float & spatial_scale [[buffer(11)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, *, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t n = index / pooled_width / pooled_height / channels_out;
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
integer_t roi_start_w = round(offset_rois[1] * spatial_scale);
integer_t roi_start_h = round(offset_rois[2] * spatial_scale);
integer_t roi_end_w = round(offset_rois[3] * spatial_scale);
integer_t roi_end_h = round(offset_rois[4] * spatial_scale);
// Force too small ROIs to be 1x1
integer_t roi_width = max(roi_end_w - roi_start_w, static_cast<integer_t>(1));
integer_t roi_height = max(roi_end_h - roi_start_h, static_cast<integer_t>(1));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
integer_t hstart = static_cast<integer_t>(floor(static_cast<T>(ph) * bin_size_h));
integer_t wstart = static_cast<integer_t>(floor(static_cast<T>(pw) * bin_size_w));
integer_t hend = static_cast<integer_t>(ceil(static_cast<T>(ph + 1) * bin_size_h));
integer_t wend = static_cast<integer_t>(ceil(static_cast<T>(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height));
hend = min(max(hend + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height));
wstart = min(max(wstart + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width));
wend = min(max(wend + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width));
bool is_empty = (hend <= hstart) || (wend <= wstart);
integer_t c_in = channel_mapping[index];
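// c_in comes from channel_mapping, which the forward kernel filled in: it is the
// input channel this pooled element was averaged from, so the gradient is routed
// back to that channel only.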
T bin_area = (hend - hstart) * (wend - wstart);
T diff_val = is_empty ? static_cast<T>(0) : grad_output[index] / bin_area;
const integer_t offset = (roi_batch_ind * channels + c_in) * height * width;
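// Spread the averaged gradient uniformly over every input pixel of the bin.
// Atomic adds are required because bins from different (possibly overlapping)
// ROIs can write to the same grad_input locations.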
for (integer_t h = hstart; h < hend; ++h) {
for (integer_t w = wstart; w < wend; ++w) {
integer_t grad_input_index = h * width + w;
atomic_add_float(grad_input + offset + grad_input_index, diff_val);
}
}
} // MPS_1D_KERNEL_LOOP
}
#define REGISTER_PS_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("ps_roi_pool_backward_" #DTYPE)]] \
kernel void ps_roi_pool_backward<DTYPE, INT_DTYPE>( \
constant DTYPE * grad_output [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
constant int64_t * channel_mapping [[buffer(2)]], \
device DTYPE * grad_input [[buffer(3)]], \
constant int64_t & output_size [[buffer(4)]], \
constant int64_t & channels [[buffer(5)]], \
constant int64_t & height [[buffer(6)]], \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant int64_t & channels_out [[buffer(10)]], \
constant float & spatial_scale [[buffer(11)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
REGISTER_NMS_OP(float);
REGISTER_NMS_OP(half);
REGISTER_ROI_ALIGN_OP(float, int64_t);
REGISTER_ROI_ALIGN_OP(half, int64_t);
REGISTER_ROI_ALIGN_BACKWARD_OP(float, int64_t);
REGISTER_ROI_ALIGN_BACKWARD_OP(half, int64_t);
REGISTER_ROI_POOL_OP(float, int64_t);
REGISTER_ROI_POOL_OP(half, int64_t);
REGISTER_ROI_POOL_BACKWARD_OP(float, int64_t);
REGISTER_ROI_POOL_BACKWARD_OP(half, int64_t);
REGISTER_PS_ROI_ALIGN_OP(float, int64_t);
REGISTER_PS_ROI_ALIGN_OP(half, int64_t);
REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float, int64_t);
REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half, int64_t);
REGISTER_PS_ROI_POOL_OP(float, int64_t);
REGISTER_PS_ROI_POOL_OP(half, int64_t);
REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t);
REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t);
)VISION_METAL";
static id<MTLLibrary> compileVisionOpsLibrary(id<MTLDevice> device) {
static id<MTLLibrary> visionLibrary = nil;
if (visionLibrary) {
return visionLibrary;
}
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
visionLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding]
options:options
error:&error];
TORCH_CHECK(visionLibrary, "Failed to create the Metal vision library, error: ", [[error description] UTF8String]);
return visionLibrary;
}
static id<MTLComputePipelineState> visionPipelineState(id<MTLDevice> device, const std::string& kernel) {
static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
id<MTLComputePipelineState> pso = psoCache[kernel];
if (pso) {
return pso;
}
NSError* error = nil;
id<MTLLibrary> visionLib = compileVisionOpsLibrary(device);
id<MTLFunction> visionFunc = [visionLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
TORCH_CHECK(visionFunc, "Failed to create function state object for: ", kernel);
pso = [device newComputePipelineStateWithFunction:visionFunc error:&error];
TORCH_CHECK(pso, "Failed to create pipeline state object, error: ", [[error description] UTF8String]);
psoCache[kernel] = pso;
return pso;
}
} // namespace mps
} // namespace ops
} // namespace vision
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_kernels.h"
namespace vision {
namespace ops {
namespace {
// This should be in sync with `nmsThreadsPerBlock` in the metal kernel.
constexpr int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8;
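// For every box i the kernel writes one 64-bit word per column block b: bit j of
// that word is set when IoU(box i, box b*64 + j) exceeds iou_threshold. The mask
// is then reduced on the CPU inside nms_kernel below.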
at::Tensor nms_kernel(const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) {
using namespace at::native::mps;
TORCH_CHECK(dets.is_mps(), "dets must be an MPS tensor");
TORCH_CHECK(scores.is_mps(), "scores must be an MPS tensor");
TORCH_CHECK(dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
TORCH_CHECK(dets.size(1) == 4, "boxes should have 4 elements in dimension 1, got ", dets.size(1));
TORCH_CHECK(scores.dim() == 1, "scores should be a 1d tensor, got ", scores.dim(), "D");
TORCH_CHECK(dets.size(0) == scores.size(0),
"boxes and scores should have the same number of elements in ",
"dimension 0, got ",
dets.size(0),
" and ",
scores.size(0));
if (dets.numel() == 0) {
return at::empty({0}, dets.options().dtype(at::kLong));
}
auto order_t = std::get<1>(scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
auto dets_sorted = dets.index_select(0, order_t).contiguous();
int64_t dets_num = dets.size(0);
float iou_threshold_f = static_cast<float>(iou_threshold);
const int col_blocks = (dets_num + nmsThreadsPerBlock - 1) / nmsThreadsPerBlock;
at::Tensor mask = at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));
id<MTLBuffer> inputBuffer = getMTLBufferStorage(dets_sorted);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(mask);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
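// One threadgroup per (row block, column block) pair of 64-box tiles, so the
// grid covers every pair of boxes in dets_sorted.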
MTLSize threadgroupsPerGrid = MTLSizeMake(col_blocks, col_blocks, 1);
const std::string kernel = "nms_" + scalarToMetalTypeString(dets_sorted.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {dets, scores});
[computeEncoder setComputePipelineState:visionPSO];
[computeEncoder setBuffer:inputBuffer offset:dets_sorted.storage_offset() * dets_sorted.element_size() atIndex:0];
[computeEncoder setBuffer:outputBuffer offset:mask.storage_offset() * mask.element_size() atIndex:1];
[computeEncoder setBytes:&dets_num length:sizeof(int64_t) atIndex:2];
[computeEncoder setBytes:&iou_threshold_f length:sizeof(float) atIndex:3];
// A Metal threadgroup is equivalent to a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > nmsThreadsPerBlock) {
tgSize = nmsThreadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
int64_t num_to_keep = 0;
at::Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr<int64_t>();
std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
at::Tensor keep = at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
int64_t* keep_out = keep.data_ptr<int64_t>();
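// Greedy suppression pass: boxes are visited in descending score order; a box is
// kept unless an already-kept box has flagged it in remv, and a kept box ORs its
// row of the mask into remv so it can suppress later overlapping boxes.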
for (int64_t i = 0; i < dets_num; i++) {
int64_t nblock = i / nmsThreadsPerBlock;
int64_t inblock = i % nmsThreadsPerBlock;
if (!(remv[nblock] & (1ULL << inblock))) {
keep_out[num_to_keep++] = i;
unsigned long long* p = mask_host + i * col_blocks;
for (int64_t j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
return order_t.index(
{keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(order_t.device(), keep.scalar_type())});
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}
} // namespace ops
} // namespace vision
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_helpers.h"
#include "mps_kernels.h"
namespace vision {
namespace ops {
namespace {
std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_kernel(const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio) {
using namespace at::native::mps;
TORCH_CHECK(input.is_mps(), "input must be an MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be an MPS tensor");
TORCH_CHECK(rois.size(1) == 5, "rois must have shape [K, 5]");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "ps_roi_align_forward_kernel";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});
int64_t num_rois = rois.size(0);
int64_t channels = input.size(1);
int64_t height = input.size(2);
int64_t width = input.size(3);
float spatial_scale_f = static_cast<float>(spatial_scale);
TORCH_CHECK(channels % (pooled_height * pooled_width) == 0,
"input channels must be a multiple of pooling height * pooling width");
int64_t channels_out = channels / (pooled_height * pooled_width);
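// Position-sensitive pooling: each (ph, pw) bin reads from its own group of
// input channels, which is why channels must be divisible by
// pooled_height * pooled_width.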
auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options());
auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong));
int64_t output_size = output.numel();
if (output_size == 0) {
return std::make_tuple(output, channel_mapping);
}
auto input_ = input.contiguous();
auto rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
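// Launch a 1D grid of at most 4096 threadgroups of up to threadsPerBlock threads;
// the kernel grid-strides over output_size, so every output element is still
// covered when output_size exceeds the number of launched threads.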
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "ps_roi_align_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
[computeEncoder setBuffer:channelMappingBuffer
offset:channel_mapping.storage_offset() * channel_mapping.element_size()
atIndex:3];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10];
[computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12];
// A Metal threadgroup is equivalent to a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return std::make_tuple(output, channel_mapping);
}
at::Tensor ps_roi_align_backward_kernel(const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
using namespace at::native::mps;
TORCH_CHECK(grad.is_mps(), "grad must be an MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be an MPS tensor");
TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_align backward with float16 inputs.");
TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be an MPS tensor");
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
channel_mapping_t{channel_mapping, "channel_mapping", 3};
at::CheckedFrom c = "ps_roi_align_backward_kernel";
at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
at::checkAllSameType(c, {grad_t, rois_t});
float spatial_scale_f = static_cast<float>(spatial_scale);
auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
if (grad.numel() == 0) {
return grad_input;
}
int64_t output_size = grad.numel();
int64_t channels_out = channels / (pooled_height * pooled_width);
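// The backward accumulates into grad_input with atomic adds (overlapping ROIs
// touch the same entries), so results are not bitwise deterministic; report this
// to the determinism checker before launching the kernel.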
at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel");
auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad_);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "ps_roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:channelMappingBuffer
offset:channel_mapping.storage_offset() * channel_mapping.element_size()
atIndex:2];
[computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10];
[computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12];
// A Metal threadgroup is equivalent to a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return grad_input;
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), TORCH_FN(ps_roi_align_forward_kernel));
m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), TORCH_FN(ps_roi_align_backward_kernel));
}
} // namespace ops
} // namespace vision
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_helpers.h"
#include "mps_kernels.h"
namespace vision {
namespace ops {
namespace {
std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_kernel(const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
using namespace at::native::mps;
TORCH_CHECK(input.is_mps(), "input must be an MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be an MPS tensor");
TORCH_CHECK(rois.size(1) == 5, "rois must have shape [K, 5]");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "ps_roi_pool_forward_kernel";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});
int64_t num_rois = rois.size(0);
int64_t channels = input.size(1);
int64_t height = input.size(2);
int64_t width = input.size(3);
float spatial_scale_f = static_cast<float>(spatial_scale);
TORCH_CHECK(channels % (pooled_height * pooled_width) == 0,
"input channels must be a multiple of pooling height * pooling width");
int64_t channels_out = channels / (pooled_height * pooled_width);
auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options());
auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong));
auto output_size = output.numel();
if (output_size == 0) {
return std::make_tuple(output, channel_mapping);
}
auto input_ = input.contiguous();
auto rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "ps_roi_pool_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
[computeEncoder setBuffer:channelMappingBuffer
offset:channel_mapping.storage_offset() * channel_mapping.element_size()
atIndex:3];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
// A Metal threadgroup is equivalent to a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return std::make_tuple(output, channel_mapping);
}
at::Tensor ps_roi_pool_backward_kernel(const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
using namespace at::native::mps;
TORCH_CHECK(grad.is_mps(), "grad must be an MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be an MPS tensor");
TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_pool backward with float16 inputs.");
TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be an MPS tensor");
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
channel_mapping_t{channel_mapping, "channel_mapping", 3};
at::CheckedFrom c = "ps_roi_pool_backward_kernel";
at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
at::checkAllSameType(c, {grad_t, rois_t});
float spatial_scale_f = static_cast<float>(spatial_scale);
auto num_rois = rois.size(0);
auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
if (grad.numel() == 0) {
return grad_input;
}
int64_t channels_out = channels / (pooled_height * pooled_width);
int64_t output_size = grad.numel();
at::globalContext().alertNotDeterministic("ps_roi_pool_backward_kernel");
auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad_);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "ps_roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad_, rois_, channel_mapping});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:channelMappingBuffer
offset:channel_mapping.storage_offset() * channel_mapping.element_size()
atIndex:2];
[computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
// A Metal threadgroup is equivalent to a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return grad_input;
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), TORCH_FN(ps_roi_pool_forward_kernel));
m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), TORCH_FN(ps_roi_pool_backward_kernel));
}
} // namespace ops
} // namespace vision
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_helpers.h"
#include "mps_kernels.h"
namespace vision {
namespace ops {
namespace {
at::Tensor roi_align_forward_kernel(const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned) {
using namespace at::native::mps;
TORCH_CHECK(input.is_mps(), "input must be an MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be an MPS tensor");
TORCH_CHECK(rois.size(1) == 5, "rois must have shape [K, 5]");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "roi_align_forward_kernel";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});
int64_t num_rois = rois.size(0);
int64_t channels = input.size(1);
int64_t height = input.size(2);
int64_t width = input.size(3);
float spatial_scale_f = static_cast<float>(spatial_scale);
at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
int64_t output_size = num_rois * pooled_height * pooled_width * channels;
if (output.numel() == 0) {
return output;
}
auto input_ = input.contiguous();
auto rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
// A Metal threadgroup is equivalent to a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return output;
}
at::Tensor roi_align_backward_kernel(const at::Tensor& grad,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t sampling_ratio,
bool aligned) {
using namespace at::native::mps;
TORCH_CHECK(grad.is_mps(), "grad must be an MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be an MPS tensor");
TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_align backward with float16 inputs.");
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "roi_align_backward_kernel";
at::checkAllSameGPU(c, {grad_t, rois_t});
at::checkAllSameType(c, {grad_t, rois_t});
float spatial_scale_f = static_cast<float>(spatial_scale);
at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
if (grad.numel() == 0) {
return grad_input;
}
int64_t n_stride = grad.stride(0);
int64_t c_stride = grad.stride(1);
int64_t h_stride = grad.stride(2);
int64_t w_stride = grad.stride(3);
int64_t output_size = grad.numel();
at::globalContext().alertNotDeterministic("roi_align_backward_kernel");
auto rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
[computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:12];
[computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13];
[computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14];
[computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15];
// A Metal threadgroup is equivalent to a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return grad_input;
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_align"), TORCH_FN(roi_align_forward_kernel));
m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), TORCH_FN(roi_align_backward_kernel));
}
} // namespace ops
} // namespace vision
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_helpers.h"
#include "mps_kernels.h"
namespace vision {
namespace ops {
namespace {
std::tuple<at::Tensor, at::Tensor> roi_pool_forward_kernel(const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
using namespace at::native::mps;
TORCH_CHECK(input.is_mps(), "input must be an MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be an MPS tensor");
TORCH_CHECK(rois.size(1) == 5, "rois must have shape [K, 5]");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "roi_pool_forward_kernel";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});
int64_t num_rois = rois.size(0);
int64_t channels = input.size(1);
int64_t height = input.size(2);
int64_t width = input.size(3);
float spatial_scale_f = static_cast<float>(spatial_scale);
at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kLong));
int64_t output_size = num_rois * pooled_height * pooled_width * channels;
if (output.numel() == 0) {
return std::make_tuple(output, argmax);
}
auto input_ = input.contiguous();
auto rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
id<MTLBuffer> argmaxBuffer = getMTLBufferStorage(argmax);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "roi_pool_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
[computeEncoder setBuffer:argmaxBuffer offset:argmax.storage_offset() * argmax.element_size() atIndex:3];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10];
// A Metal threadgroup is equivalent to a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return std::make_tuple(output, argmax);
}
at::Tensor roi_pool_backward_kernel(const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
using namespace at::native::mps;
TORCH_CHECK(grad.is_mps(), "grad must be an MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be an MPS tensor");
TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_pool backward with float16 inputs.");
TORCH_CHECK(argmax.is_mps(), "argmax must be an MPS tensor");
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, argmax_t{argmax, "argmax", 3};
at::CheckedFrom c = "roi_pool_backward_kernel";
at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t});
at::checkAllSameType(c, {grad_t, rois_t});
float spatial_scale_f = static_cast<float>(spatial_scale);
at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
if (grad.numel() == 0) {
return grad_input;
}
int64_t n_stride = grad.stride(0);
int64_t c_stride = grad.stride(1);
int64_t h_stride = grad.stride(2);
int64_t w_stride = grad.stride(3);
int64_t output_size = grad.numel();
at::globalContext().alertNotDeterministic("roi_pool_backward_kernel");
auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> argmaxBuffer = getMTLBufferStorage(argmax_);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_, argmax_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:argmaxBuffer offset:argmax_.storage_offset() * argmax_.element_size() atIndex:2];
[computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10];
[computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:11];
[computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:12];
[computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:13];
[computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:14];
// A Metal threadgroup is equivalent to a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return grad_input;
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_pool"), TORCH_FN(roi_pool_forward_kernel));
m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), TORCH_FN(roi_pool_backward_kernel));
}
} // namespace ops
} // namespace vision
...@@ -164,7 +164,7 @@ void qroi_align_forward_kernel_impl( ...@@ -164,7 +164,7 @@ void qroi_align_forward_kernel_impl(
const float count = const float count =
std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4 std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
// we want to precalculate indices and weights shared by all chanels, // we want to precalculate indices and weights shared by all channels,
// this is the key point of optimization // this is the key point of optimization
std::vector<detail::PreCalc<float>> pre_calc( std::vector<detail::PreCalc<float>> pre_calc(
roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height); roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
......
...@@ -32,6 +32,31 @@ at::Tensor roi_align( ...@@ -32,6 +32,31 @@ at::Tensor roi_align(
aligned); aligned);
} }
at::Tensor roi_align_symint(
const at::Tensor& input, // Input feature map.
const at::Tensor& rois, // List of ROIs to pool over.
double spatial_scale, // The scale of the image features. ROIs will be
// scaled to this.
c10::SymInt pooled_height, // The height of the pooled feature map.
c10::SymInt pooled_width, // The width of the pooled feature
int64_t sampling_ratio, // The number of points to sample in each bin
bool aligned) // The flag for pixel shift
// along each axis.
{
C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.roi_align.roi_align");
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::roi_align", "")
.typed<decltype(roi_align_symint)>();
return op.call(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
aligned);
}
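// The *_symint overloads mirror the int64_t variants but accept c10::SymInt
// sizes, so the op can be traced with symbolic shapes (e.g. under torch.compile)
// instead of requiring concrete values; the schemas below switch int to SymInt
// accordingly.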
namespace detail { namespace detail {
at::Tensor _roi_align_backward( at::Tensor _roi_align_backward(
...@@ -64,13 +89,43 @@ at::Tensor _roi_align_backward( ...@@ -64,13 +89,43 @@ at::Tensor _roi_align_backward(
aligned); aligned);
} }
at::Tensor _roi_align_backward_symint(
const at::Tensor& grad,
const at::Tensor& rois,
double spatial_scale,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
c10::SymInt batch_size,
c10::SymInt channels,
c10::SymInt height,
c10::SymInt width,
int64_t sampling_ratio,
bool aligned) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_roi_align_backward", "")
.typed<decltype(_roi_align_backward_symint)>();
return op.call(
grad,
rois,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width,
sampling_ratio,
aligned);
}
} // namespace detail } // namespace detail
TORCH_LIBRARY_FRAGMENT(torchvision, m) { TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(TORCH_SELECTIVE_SCHEMA( m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor")); "torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, bool aligned) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA( m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor")); "torchvision::_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width, int sampling_ratio, bool aligned) -> Tensor"));
} }
} // namespace ops } // namespace ops
......
...@@ -15,6 +15,15 @@ VISION_API at::Tensor roi_align( ...@@ -15,6 +15,15 @@ VISION_API at::Tensor roi_align(
int64_t sampling_ratio, int64_t sampling_ratio,
bool aligned); bool aligned);
VISION_API at::Tensor roi_align_symint(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio,
bool aligned);
namespace detail { namespace detail {
at::Tensor _roi_align_backward( at::Tensor _roi_align_backward(
...@@ -30,6 +39,19 @@ at::Tensor _roi_align_backward( ...@@ -30,6 +39,19 @@ at::Tensor _roi_align_backward(
int64_t sampling_ratio, int64_t sampling_ratio,
bool aligned); bool aligned);
at::Tensor _roi_align_backward_symint(
const at::Tensor& grad,
const at::Tensor& rois,
double spatial_scale,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
c10::SymInt batch_size,
c10::SymInt channels,
c10::SymInt height,
c10::SymInt width,
int64_t sampling_ratio,
bool aligned);
} // namespace detail } // namespace detail
} // namespace ops } // namespace ops
......
...@@ -36,6 +36,7 @@ from .kitti import Kitti ...@@ -36,6 +36,7 @@ from .kitti import Kitti
from .lfw import LFWPairs, LFWPeople from .lfw import LFWPairs, LFWPeople
from .lsun import LSUN, LSUNClass from .lsun import LSUN, LSUNClass
from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
from .moving_mnist import MovingMNIST
from .omniglot import Omniglot from .omniglot import Omniglot
from .oxford_iiit_pet import OxfordIIITPet from .oxford_iiit_pet import OxfordIIITPet
from .pcam import PCAM from .pcam import PCAM
...@@ -126,4 +127,18 @@ __all__ = ( ...@@ -126,4 +127,18 @@ __all__ = (
"SintelStereo", "SintelStereo",
"InStereo2k", "InStereo2k",
"ETH3DStereo", "ETH3DStereo",
"wrap_dataset_for_transforms_v2",
) )
# We override the current module's attribute lookup to handle the import:
# from torchvision.datasets import wrap_dataset_for_transforms_v2
# without a circular import error.
# Ref: https://peps.python.org/pep-0562/
def __getattr__(name):
if name in ("wrap_dataset_for_transforms_v2",):
from torchvision.tv_tensors._dataset_wrapper import wrap_dataset_for_transforms_v2
return wrap_dataset_for_transforms_v2
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
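# Illustrative use of the lazy attribute (the dataset and paths below are placeholders):
#   from torchvision.datasets import CocoDetection, wrap_dataset_for_transforms_v2
#   dataset = wrap_dataset_for_transforms_v2(CocoDetection(img_root, ann_file))
# wrap_dataset_for_transforms_v2 is only imported on first attribute access.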
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from glob import glob from glob import glob
from pathlib import Path from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -13,6 +14,10 @@ from .utils import _read_pfm, verify_str_arg ...@@ -13,6 +14,10 @@ from .utils import _read_pfm, verify_str_arg
from .vision import VisionDataset from .vision import VisionDataset
T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]
__all__ = ( __all__ = (
"KittiFlow", "KittiFlow",
"Sintel", "Sintel",
...@@ -28,26 +33,26 @@ class FlowDataset(ABC, VisionDataset): ...@@ -28,26 +33,26 @@ class FlowDataset(ABC, VisionDataset):
# and it's up to whatever consumes the dataset to decide what valid_flow_mask should be. # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
_has_builtin_flow_mask = False _has_builtin_flow_mask = False
def __init__(self, root, transforms=None): def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
super().__init__(root=root) super().__init__(root=root)
self.transforms = transforms self.transforms = transforms
self._flow_list = [] self._flow_list: List[str] = []
self._image_list = [] self._image_list: List[List[str]] = []
def _read_img(self, file_name): def _read_img(self, file_name: str) -> Image.Image:
img = Image.open(file_name) img = Image.open(file_name)
if img.mode != "RGB": if img.mode != "RGB":
img = img.convert("RGB") img = img.convert("RGB")
return img return img
@abstractmethod @abstractmethod
def _read_flow(self, file_name): def _read_flow(self, file_name: str):
# Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True
pass pass
def __getitem__(self, index): def __getitem__(self, index: int) -> Union[T1, T2]:
img1 = self._read_img(self._image_list[index][0]) img1 = self._read_img(self._image_list[index][0])
img2 = self._read_img(self._image_list[index][1]) img2 = self._read_img(self._image_list[index][1])
...@@ -70,10 +75,10 @@ class FlowDataset(ABC, VisionDataset): ...@@ -70,10 +75,10 @@ class FlowDataset(ABC, VisionDataset):
else: else:
return img1, img2, flow return img1, img2, flow
def __len__(self): def __len__(self) -> int:
return len(self._image_list) return len(self._image_list)
def __rmul__(self, v): def __rmul__(self, v: int) -> torch.utils.data.ConcatDataset:
return torch.utils.data.ConcatDataset([self] * v) return torch.utils.data.ConcatDataset([self] * v)
...@@ -118,7 +123,13 @@ class Sintel(FlowDataset): ...@@ -118,7 +123,13 @@ class Sintel(FlowDataset):
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
""" """
def __init__(self, root, split="train", pass_name="clean", transforms=None): def __init__(
self,
root: str,
split: str = "train",
pass_name: str = "clean",
transforms: Optional[Callable] = None,
) -> None:
super().__init__(root=root, transforms=transforms) super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
...@@ -139,7 +150,7 @@ class Sintel(FlowDataset): ...@@ -139,7 +150,7 @@ class Sintel(FlowDataset):
if split == "train": if split == "train":
self._flow_list += sorted(glob(str(flow_root / scene / "*.flo"))) self._flow_list += sorted(glob(str(flow_root / scene / "*.flo")))
def __getitem__(self, index): def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -154,7 +165,7 @@ class Sintel(FlowDataset): ...@@ -154,7 +165,7 @@ class Sintel(FlowDataset):
""" """
return super().__getitem__(index) return super().__getitem__(index)
def _read_flow(self, file_name): def _read_flow(self, file_name: str) -> np.ndarray:
return _read_flo(file_name) return _read_flo(file_name)
...@@ -180,7 +191,7 @@ class KittiFlow(FlowDataset): ...@@ -180,7 +191,7 @@ class KittiFlow(FlowDataset):
_has_builtin_flow_mask = True _has_builtin_flow_mask = True
def __init__(self, root, split="train", transforms=None): def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms) super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
...@@ -200,7 +211,7 @@ class KittiFlow(FlowDataset): ...@@ -200,7 +211,7 @@ class KittiFlow(FlowDataset):
if split == "train": if split == "train":
self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png"))) self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png")))
def __getitem__(self, index): def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -215,7 +226,7 @@ class KittiFlow(FlowDataset): ...@@ -215,7 +226,7 @@ class KittiFlow(FlowDataset):
""" """
return super().__getitem__(index) return super().__getitem__(index)
def _read_flow(self, file_name): def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
return _read_16bits_png_with_flow_and_valid_mask(file_name) return _read_16bits_png_with_flow_and_valid_mask(file_name)
...@@ -245,7 +256,7 @@ class FlyingChairs(FlowDataset): ...@@ -245,7 +256,7 @@ class FlyingChairs(FlowDataset):
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
""" """
def __init__(self, root, split="train", transforms=None): def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms) super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "val")) verify_str_arg(split, "split", valid_values=("train", "val"))
...@@ -268,7 +279,7 @@ class FlyingChairs(FlowDataset): ...@@ -268,7 +279,7 @@ class FlyingChairs(FlowDataset):
self._flow_list += [flows[i]] self._flow_list += [flows[i]]
self._image_list += [[images[2 * i], images[2 * i + 1]]] self._image_list += [[images[2 * i], images[2 * i + 1]]]
def __getitem__(self, index): def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -283,7 +294,7 @@ class FlyingChairs(FlowDataset): ...@@ -283,7 +294,7 @@ class FlyingChairs(FlowDataset):
""" """
return super().__getitem__(index) return super().__getitem__(index)
def _read_flow(self, file_name): def _read_flow(self, file_name: str) -> np.ndarray:
return _read_flo(file_name) return _read_flo(file_name)
...@@ -316,7 +327,14 @@ class FlyingThings3D(FlowDataset): ...@@ -316,7 +327,14 @@ class FlyingThings3D(FlowDataset):
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
""" """
def __init__(self, root, split="train", pass_name="clean", camera="left", transforms=None): def __init__(
self,
root: str,
split: str = "train",
pass_name: str = "clean",
camera: str = "left",
transforms: Optional[Callable] = None,
) -> None:
super().__init__(root=root, transforms=transforms) super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
...@@ -359,7 +377,7 @@ class FlyingThings3D(FlowDataset): ...@@ -359,7 +377,7 @@ class FlyingThings3D(FlowDataset):
self._image_list += [[images[i + 1], images[i]]] self._image_list += [[images[i + 1], images[i]]]
self._flow_list += [flows[i + 1]] self._flow_list += [flows[i + 1]]
def __getitem__(self, index): def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -374,7 +392,7 @@ class FlyingThings3D(FlowDataset): ...@@ -374,7 +392,7 @@ class FlyingThings3D(FlowDataset):
""" """
return super().__getitem__(index) return super().__getitem__(index)
def _read_flow(self, file_name): def _read_flow(self, file_name: str) -> np.ndarray:
return _read_pfm(file_name) return _read_pfm(file_name)
...@@ -401,7 +419,7 @@ class HD1K(FlowDataset): ...@@ -401,7 +419,7 @@ class HD1K(FlowDataset):
_has_builtin_flow_mask = True _has_builtin_flow_mask = True
def __init__(self, root, split="train", transforms=None): def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms) super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
...@@ -426,10 +444,10 @@ class HD1K(FlowDataset): ...@@ -426,10 +444,10 @@ class HD1K(FlowDataset):
"Could not find the HD1K images. Please make sure the directory structure is correct." "Could not find the HD1K images. Please make sure the directory structure is correct."
) )
def _read_flow(self, file_name): def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
return _read_16bits_png_with_flow_and_valid_mask(file_name) return _read_16bits_png_with_flow_and_valid_mask(file_name)
def __getitem__(self, index): def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -445,7 +463,7 @@ class HD1K(FlowDataset): ...@@ -445,7 +463,7 @@ class HD1K(FlowDataset):
return super().__getitem__(index) return super().__getitem__(index)
def _read_flo(file_name): def _read_flo(file_name: str) -> np.ndarray:
"""Read .flo file in Middlebury format""" """Read .flo file in Middlebury format"""
# Code adapted from: # Code adapted from:
# http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
...@@ -462,7 +480,7 @@ def _read_flo(file_name): ...@@ -462,7 +480,7 @@ def _read_flo(file_name):
return data.reshape(h, w, 2).transpose(2, 0, 1) return data.reshape(h, w, 2).transpose(2, 0, 1)
def _read_16bits_png_with_flow_and_valid_mask(file_name): def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> Tuple[np.ndarray, np.ndarray]:
flow_and_valid = _read_png_16(file_name).to(torch.float32) flow_and_valid = _read_png_16(file_name).to(torch.float32)
flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :] flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
......
...@@ -6,7 +6,7 @@ import shutil ...@@ -6,7 +6,7 @@ import shutil
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from glob import glob from glob import glob
from pathlib import Path from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, cast, List, Optional, Tuple, Union
import numpy as np import numpy as np
from PIL import Image from PIL import Image
...@@ -14,6 +14,9 @@ from PIL import Image ...@@ -14,6 +14,9 @@ from PIL import Image
from .utils import _read_pfm, download_and_extract_archive, verify_str_arg from .utils import _read_pfm, download_and_extract_archive, verify_str_arg
from .vision import VisionDataset from .vision import VisionDataset
T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], np.ndarray]
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]
__all__ = () __all__ = ()
_read_pfm_file = functools.partial(_read_pfm, slice_channels=1) _read_pfm_file = functools.partial(_read_pfm, slice_channels=1)
...@@ -24,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset): ...@@ -24,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset):
_has_built_in_disparity_mask = False _has_built_in_disparity_mask = False
def __init__(self, root: str, transforms: Optional[Callable] = None): def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
""" """
Args: Args:
root(str): Root directory of the dataset. root(str): Root directory of the dataset.
...@@ -58,7 +61,11 @@ class StereoMatchingDataset(ABC, VisionDataset): ...@@ -58,7 +61,11 @@ class StereoMatchingDataset(ABC, VisionDataset):
img = img.convert("RGB") img = img.convert("RGB")
return img return img
def _scan_pairs(self, paths_left_pattern: str, paths_right_pattern: Optional[str] = None): def _scan_pairs(
self,
paths_left_pattern: str,
paths_right_pattern: Optional[str] = None,
) -> List[Tuple[str, Optional[str]]]:
left_paths = list(sorted(glob(paths_left_pattern))) left_paths = list(sorted(glob(paths_left_pattern)))
...@@ -85,11 +92,11 @@ class StereoMatchingDataset(ABC, VisionDataset): ...@@ -85,11 +92,11 @@ class StereoMatchingDataset(ABC, VisionDataset):
return paths return paths
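A hedged illustration of how the subclasses below call _scan_pairs (the patterns are hypothetical):
left_image_pattern = str(root / "training" / "*" / "im0.png")
right_image_pattern = str(root / "training" / "*" / "im1.png")
self._images = self._scan_pairs(left_image_pattern, right_image_pattern)
# both globs are sorted and zipped positionally into (left, right) tuples;
# with paths_right_pattern=None every tuple is (left_path, None), as for the ETH3D disparities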
@abstractmethod @abstractmethod
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
# function that returns a disparity map and an occlusion map # function that returns a disparity map and an occlusion map
pass pass
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -120,7 +127,7 @@ class StereoMatchingDataset(ABC, VisionDataset): ...@@ -120,7 +127,7 @@ class StereoMatchingDataset(ABC, VisionDataset):
) = self.transforms(imgs, dsp_maps, valid_masks) ) = self.transforms(imgs, dsp_maps, valid_masks)
if self._has_built_in_disparity_mask or valid_masks[0] is not None: if self._has_built_in_disparity_mask or valid_masks[0] is not None:
return imgs[0], imgs[1], dsp_maps[0], valid_masks[0] return imgs[0], imgs[1], dsp_maps[0], cast(np.ndarray, valid_masks[0])
else: else:
return imgs[0], imgs[1], dsp_maps[0] return imgs[0], imgs[1], dsp_maps[0]
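A minimal usage sketch of the tuple contract above (the root path and on-disk layout are assumptions):
dataset = CarlaStereo(root="datasets")                     # no built-in mask -> 3-tuples
img_left, img_right, disparity = dataset[0]
kitti = Kitti2015Stereo(root="datasets", split="train")    # built-in mask -> 4-tuples
img_left, img_right, disparity, valid_mask = kitti[0]      # valid_mask may still be None without transforms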
...@@ -156,7 +163,7 @@ class CarlaStereo(StereoMatchingDataset): ...@@ -156,7 +163,7 @@ class CarlaStereo(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
""" """
def __init__(self, root: str, transforms: Optional[Callable] = None): def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
root = Path(root) / "carla-highres" root = Path(root) / "carla-highres"
...@@ -171,13 +178,13 @@ class CarlaStereo(StereoMatchingDataset): ...@@ -171,13 +178,13 @@ class CarlaStereo(StereoMatchingDataset):
disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
self._disparities = disparities self._disparities = disparities
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
disparity_map = _read_pfm_file(file_path) disparity_map = _read_pfm_file(file_path)
disparity_map = np.abs(disparity_map) # ensure that the disparity is positive disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
valid_mask = None valid_mask = None
return disparity_map, valid_mask return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T1:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -189,7 +196,7 @@ class CarlaStereo(StereoMatchingDataset): ...@@ -189,7 +196,7 @@ class CarlaStereo(StereoMatchingDataset):
If a ``valid_mask`` is generated within the ``transforms`` parameter, If a ``valid_mask`` is generated within the ``transforms`` parameter,
a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
""" """
return super().__getitem__(index) return cast(T1, super().__getitem__(index))
class Kitti2012Stereo(StereoMatchingDataset): class Kitti2012Stereo(StereoMatchingDataset):
...@@ -233,7 +240,7 @@ class Kitti2012Stereo(StereoMatchingDataset): ...@@ -233,7 +240,7 @@ class Kitti2012Stereo(StereoMatchingDataset):
_has_built_in_disparity_mask = True _has_built_in_disparity_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
...@@ -250,7 +257,7 @@ class Kitti2012Stereo(StereoMatchingDataset): ...@@ -250,7 +257,7 @@ class Kitti2012Stereo(StereoMatchingDataset):
else: else:
self._disparities = list((None, None) for _ in self._images) self._disparities = list((None, None) for _ in self._images)
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
# test split has no disparity maps # test split has no disparity maps
if file_path is None: if file_path is None:
return None, None return None, None
...@@ -261,7 +268,7 @@ class Kitti2012Stereo(StereoMatchingDataset): ...@@ -261,7 +268,7 @@ class Kitti2012Stereo(StereoMatchingDataset):
valid_mask = None valid_mask = None
return disparity_map, valid_mask return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T1:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -274,7 +281,7 @@ class Kitti2012Stereo(StereoMatchingDataset): ...@@ -274,7 +281,7 @@ class Kitti2012Stereo(StereoMatchingDataset):
generate a valid mask. generate a valid mask.
Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
""" """
return super().__getitem__(index) return cast(T1, super().__getitem__(index))
class Kitti2015Stereo(StereoMatchingDataset): class Kitti2015Stereo(StereoMatchingDataset):
...@@ -321,7 +328,7 @@ class Kitti2015Stereo(StereoMatchingDataset): ...@@ -321,7 +328,7 @@ class Kitti2015Stereo(StereoMatchingDataset):
_has_built_in_disparity_mask = True _has_built_in_disparity_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
...@@ -338,7 +345,7 @@ class Kitti2015Stereo(StereoMatchingDataset): ...@@ -338,7 +345,7 @@ class Kitti2015Stereo(StereoMatchingDataset):
else: else:
self._disparities = list((None, None) for _ in self._images) self._disparities = list((None, None) for _ in self._images)
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
# test split has no disparity maps # test split has no disparity maps
if file_path is None: if file_path is None:
return None, None return None, None
...@@ -349,7 +356,7 @@ class Kitti2015Stereo(StereoMatchingDataset): ...@@ -349,7 +356,7 @@ class Kitti2015Stereo(StereoMatchingDataset):
valid_mask = None valid_mask = None
return disparity_map, valid_mask return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T1:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -362,7 +369,7 @@ class Kitti2015Stereo(StereoMatchingDataset): ...@@ -362,7 +369,7 @@ class Kitti2015Stereo(StereoMatchingDataset):
generate a valid mask. generate a valid mask.
Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
""" """
return super().__getitem__(index) return cast(T1, super().__getitem__(index))
class Middlebury2014Stereo(StereoMatchingDataset): class Middlebury2014Stereo(StereoMatchingDataset):
...@@ -417,9 +424,9 @@ class Middlebury2014Stereo(StereoMatchingDataset): ...@@ -417,9 +424,9 @@ class Middlebury2014Stereo(StereoMatchingDataset):
split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional" split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional"
use_ambient_views (boolean, optional): Whether to use different exposure or lighting views when possible. use_ambient_views (boolean, optional): Whether to use different exposure or lighting views when possible.
The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``. The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``.
calibration (string, optional): Wether or not to use the calibrated (default) or uncalibrated scenes. calibration (string, optional): Whether or not to use the calibrated (default) or uncalibrated scenes.
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
download (boolean, optional): Wether or not to download the dataset in the ``root`` directory. download (boolean, optional): Whether or not to download the dataset in the ``root`` directory.
""" """
splits = { splits = {
...@@ -479,7 +486,7 @@ class Middlebury2014Stereo(StereoMatchingDataset): ...@@ -479,7 +486,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
use_ambient_views: bool = False, use_ambient_views: bool = False,
transforms: Optional[Callable] = None, transforms: Optional[Callable] = None,
download: bool = False, download: bool = False,
): ) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
verify_str_arg(split, "split", valid_values=("train", "test", "additional")) verify_str_arg(split, "split", valid_values=("train", "test", "additional"))
...@@ -558,7 +565,7 @@ class Middlebury2014Stereo(StereoMatchingDataset): ...@@ -558,7 +565,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
file_path = random.choice(ambient_file_paths) # type: ignore file_path = random.choice(ambient_file_paths) # type: ignore
return super()._read_img(file_path) return super()._read_img(file_path)
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
# test split has no disparity maps # test split has no disparity maps
if file_path is None: if file_path is None:
return None, None return None, None
...@@ -569,7 +576,7 @@ class Middlebury2014Stereo(StereoMatchingDataset): ...@@ -569,7 +576,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities
return disparity_map, valid_mask return disparity_map, valid_mask
def _download_dataset(self, root: str): def _download_dataset(self, root: str) -> None:
base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip" base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
# train and additional splits have 2 different calibration settings # train and additional splits have 2 different calibration settings
root = Path(root) / "Middlebury2014" root = Path(root) / "Middlebury2014"
...@@ -608,7 +615,7 @@ class Middlebury2014Stereo(StereoMatchingDataset): ...@@ -608,7 +615,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
# cleanup MiddEval3 directory # cleanup MiddEval3 directory
shutil.rmtree(str(root / "MiddEval3")) shutil.rmtree(str(root / "MiddEval3"))
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T2:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -619,7 +626,7 @@ class Middlebury2014Stereo(StereoMatchingDataset): ...@@ -619,7 +626,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
The disparity is a numpy array of shape (1, H, W) and the images are PIL images. The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
``valid_mask`` is implicitly ``None`` for `split=test`. ``valid_mask`` is implicitly ``None`` for `split=test`.
""" """
return super().__getitem__(index) return cast(T2, super().__getitem__(index))
class CREStereo(StereoMatchingDataset): class CREStereo(StereoMatchingDataset):
...@@ -670,7 +677,7 @@ class CREStereo(StereoMatchingDataset): ...@@ -670,7 +677,7 @@ class CREStereo(StereoMatchingDataset):
self, self,
root: str, root: str,
transforms: Optional[Callable] = None, transforms: Optional[Callable] = None,
): ) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
root = Path(root) / "CREStereo" root = Path(root) / "CREStereo"
...@@ -688,14 +695,14 @@ class CREStereo(StereoMatchingDataset): ...@@ -688,14 +695,14 @@ class CREStereo(StereoMatchingDataset):
disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
self._disparities += disparities self._disparities += disparities
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
disparity_map = np.asarray(Image.open(file_path), dtype=np.float32) disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
# unsqueeze the disparity map into (C, H, W) format # unsqueeze the disparity map into (C, H, W) format
disparity_map = disparity_map[None, :, :] / 32.0 disparity_map = disparity_map[None, :, :] / 32.0
valid_mask = None valid_mask = None
return disparity_map, valid_mask return disparity_map, valid_mask
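As a worked example of the 1/32 fixed-point scaling above (assuming 16-bit disparity PNGs): a stored value of 1600 decodes to 1600 / 32.0 = 50.0 px, so the encoding has 1/32 px resolution and a ceiling of 65535 / 32 ≈ 2048 px.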
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T1:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -707,13 +714,13 @@ class CREStereo(StereoMatchingDataset): ...@@ -707,13 +714,13 @@ class CREStereo(StereoMatchingDataset):
``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
generate a valid mask. generate a valid mask.
""" """
return super().__getitem__(index) return cast(T1, super().__getitem__(index))
class FallingThingsStereo(StereoMatchingDataset): class FallingThingsStereo(StereoMatchingDataset):
"""`FallingThings <https://research.nvidia.com/publication/2018-06_falling-things-synthetic-dataset-3d-object-detection-and-pose-estimation>`_ dataset. """`FallingThings <https://research.nvidia.com/publication/2018-06_falling-things-synthetic-dataset-3d-object-detection-and-pose-estimation>`_ dataset.
The dataset is expected to have the following structre: :: The dataset is expected to have the following structure: ::
root root
FallingThings FallingThings
...@@ -755,7 +762,7 @@ class FallingThingsStereo(StereoMatchingDataset): ...@@ -755,7 +762,7 @@ class FallingThingsStereo(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
""" """
def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None): def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
root = Path(root) / "FallingThings" root = Path(root) / "FallingThings"
...@@ -782,14 +789,14 @@ class FallingThingsStereo(StereoMatchingDataset): ...@@ -782,14 +789,14 @@ class FallingThingsStereo(StereoMatchingDataset):
right_disparity_pattern = str(root / s / split_prefix[s] / "*.right.depth.png") right_disparity_pattern = str(root / s / split_prefix[s] / "*.right.depth.png")
self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern) self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
# (H, W) image # (H, W) image
depth = np.asarray(Image.open(file_path)) depth = np.asarray(Image.open(file_path))
# as per https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt # as per https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt
# in order to extract disparity from depth maps # in order to extract disparity from depth maps
camera_settings_path = Path(file_path).parent / "_camera_settings.json" camera_settings_path = Path(file_path).parent / "_camera_settings.json"
with open(camera_settings_path, "r") as f: with open(camera_settings_path, "r") as f:
# inverse of depth-from-disparity equation: depth = (baseline * focal) / (disparity * pixel_constatnt) # inverse of depth-from-disparity equation: depth = (baseline * focal) / (disparity * pixel_constant)
intrinsics = json.load(f) intrinsics = json.load(f)
focal = intrinsics["camera_settings"][0]["intrinsic_settings"]["fx"] focal = intrinsics["camera_settings"][0]["intrinsic_settings"]["fx"]
baseline, pixel_constant = 6, 100 # pixel constant is inverted baseline, pixel_constant = 6, 100 # pixel constant is inverted
...@@ -799,7 +806,7 @@ class FallingThingsStereo(StereoMatchingDataset): ...@@ -799,7 +806,7 @@ class FallingThingsStereo(StereoMatchingDataset):
valid_mask = None valid_mask = None
return disparity_map, valid_mask return disparity_map, valid_mask
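A hedged worked example of the depth-to-disparity rearrangement described in the comment above: solving depth = (baseline * focal) / (disparity * pixel_constant) for disparity gives disparity = (baseline * focal) / (depth * pixel_constant); with focal = 768 px, baseline = 6 and pixel_constant = 100, a depth value of 0.288 maps to (6 * 768) / (0.288 * 100) = 160 px. The "pixel constant is inverted" remark means the stored depth units fold the constant the other way, so the exact placement of that factor is an assumption here.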
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T1:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -811,14 +818,14 @@ class FallingThingsStereo(StereoMatchingDataset): ...@@ -811,14 +818,14 @@ class FallingThingsStereo(StereoMatchingDataset):
If a ``valid_mask`` is generated within the ``transforms`` parameter, If a ``valid_mask`` is generated within the ``transforms`` parameter,
a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
""" """
return super().__getitem__(index) return cast(T1, super().__getitem__(index))
class SceneFlowStereo(StereoMatchingDataset): class SceneFlowStereo(StereoMatchingDataset):
"""Dataset interface for `Scene Flow <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ datasets. """Dataset interface for `Scene Flow <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ datasets.
This interface provides access to the `FlyingThings3D`, `Monkaa` and `Driving` datasets. This interface provides access to the `FlyingThings3D`, `Monkaa` and `Driving` datasets.
The dataset is expected to have the following structre: :: The dataset is expected to have the following structure: ::
root root
SceneFlow SceneFlow
...@@ -874,7 +881,7 @@ class SceneFlowStereo(StereoMatchingDataset): ...@@ -874,7 +881,7 @@ class SceneFlowStereo(StereoMatchingDataset):
variant: str = "FlyingThings3D", variant: str = "FlyingThings3D",
pass_name: str = "clean", pass_name: str = "clean",
transforms: Optional[Callable] = None, transforms: Optional[Callable] = None,
): ) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
root = Path(root) / "SceneFlow" root = Path(root) / "SceneFlow"
...@@ -905,13 +912,13 @@ class SceneFlowStereo(StereoMatchingDataset): ...@@ -905,13 +912,13 @@ class SceneFlowStereo(StereoMatchingDataset):
right_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "right" / "*.pfm") right_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "right" / "*.pfm")
self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern) self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
disparity_map = _read_pfm_file(file_path) disparity_map = _read_pfm_file(file_path)
disparity_map = np.abs(disparity_map) # ensure that the disparity is positive disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
valid_mask = None valid_mask = None
return disparity_map, valid_mask return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T1:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -923,7 +930,7 @@ class SceneFlowStereo(StereoMatchingDataset): ...@@ -923,7 +930,7 @@ class SceneFlowStereo(StereoMatchingDataset):
If a ``valid_mask`` is generated within the ``transforms`` parameter, If a ``valid_mask`` is generated within the ``transforms`` parameter,
a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
""" """
return super().__getitem__(index) return cast(T1, super().__getitem__(index))
class SintelStereo(StereoMatchingDataset): class SintelStereo(StereoMatchingDataset):
...@@ -973,7 +980,7 @@ class SintelStereo(StereoMatchingDataset): ...@@ -973,7 +980,7 @@ class SintelStereo(StereoMatchingDataset):
_has_built_in_disparity_mask = True _has_built_in_disparity_mask = True
def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None): def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both")) verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))
...@@ -1014,7 +1021,7 @@ class SintelStereo(StereoMatchingDataset): ...@@ -1014,7 +1021,7 @@ class SintelStereo(StereoMatchingDataset):
return occlusion_path, outofframe_path return occlusion_path, outofframe_path
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
if file_path is None: if file_path is None:
return None, None return None, None
...@@ -1024,7 +1031,7 @@ class SintelStereo(StereoMatchingDataset): ...@@ -1024,7 +1031,7 @@ class SintelStereo(StereoMatchingDataset):
disparity_map = r * 4 + g / (2**6) + b / (2**14) disparity_map = r * 4 + g / (2**6) + b / (2**14)
# reshape into (C, H, W) format # reshape into (C, H, W) format
disparity_map = np.transpose(disparity_map, (2, 0, 1)) disparity_map = np.transpose(disparity_map, (2, 0, 1))
# find the appropiate file paths # find the appropriate file paths
occlued_mask_path, out_of_frame_mask_path = self._get_occlussion_mask_paths(file_path) occlued_mask_path, out_of_frame_mask_path = self._get_occlussion_mask_paths(file_path)
# occlusion masks # occlusion masks
valid_mask = np.asarray(Image.open(occlued_mask_path)) == 0 valid_mask = np.asarray(Image.open(occlued_mask_path)) == 0
...@@ -1034,7 +1041,7 @@ class SintelStereo(StereoMatchingDataset): ...@@ -1034,7 +1041,7 @@ class SintelStereo(StereoMatchingDataset):
valid_mask = np.logical_and(off_mask, valid_mask) valid_mask = np.logical_and(off_mask, valid_mask)
return disparity_map, valid_mask return disparity_map, valid_mask
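A worked example of the RGB disparity encoding decoded above: a pixel (r, g, b) = (10, 32, 0) yields 10 * 4 + 32 / 2**6 + 0 / 2**14 = 40.5 px, so the format covers just under 1024 px at a resolution of 1 / 2**14 ≈ 6.1e-5 px; a sample stays valid only where both the occlusion and out-of-frame masks are zero.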
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T2:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -1045,13 +1052,13 @@ class SintelStereo(StereoMatchingDataset): ...@@ -1045,13 +1052,13 @@ class SintelStereo(StereoMatchingDataset):
The disparity is a numpy array of shape (1, H, W) and the images are PIL images whilst The disparity is a numpy array of shape (1, H, W) and the images are PIL images whilst
the valid_mask is a numpy array of shape (H, W). the valid_mask is a numpy array of shape (H, W).
""" """
return super().__getitem__(index) return cast(T2, super().__getitem__(index))
class InStereo2k(StereoMatchingDataset): class InStereo2k(StereoMatchingDataset):
"""`InStereo2k <https://github.com/YuhuaXu/StereoDataset>`_ dataset. """`InStereo2k <https://github.com/YuhuaXu/StereoDataset>`_ dataset.
The dataset is expected to have the following structre: :: The dataset is expected to have the following structure: ::
root root
InStereo2k InStereo2k
...@@ -1080,7 +1087,7 @@ class InStereo2k(StereoMatchingDataset): ...@@ -1080,7 +1087,7 @@ class InStereo2k(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
""" """
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
root = Path(root) / "InStereo2k" / split root = Path(root) / "InStereo2k" / split
...@@ -1095,14 +1102,14 @@ class InStereo2k(StereoMatchingDataset): ...@@ -1095,14 +1102,14 @@ class InStereo2k(StereoMatchingDataset):
right_disparity_pattern = str(root / "*" / "right_disp.png") right_disparity_pattern = str(root / "*" / "right_disp.png")
self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
disparity_map = np.asarray(Image.open(file_path), dtype=np.float32) disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
# unsqueeze disparity to (C, H, W) # unsqueeze disparity to (C, H, W)
disparity_map = disparity_map[None, :, :] / 1024.0 disparity_map = disparity_map[None, :, :] / 1024.0
valid_mask = None valid_mask = None
return disparity_map, valid_mask return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T1:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -1114,7 +1121,7 @@ class InStereo2k(StereoMatchingDataset): ...@@ -1114,7 +1121,7 @@ class InStereo2k(StereoMatchingDataset):
If a ``valid_mask`` is generated within the ``transforms`` parameter, If a ``valid_mask`` is generated within the ``transforms`` parameter,
a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
""" """
return super().__getitem__(index) return cast(T1, super().__getitem__(index))
class ETH3DStereo(StereoMatchingDataset): class ETH3DStereo(StereoMatchingDataset):
...@@ -1169,7 +1176,7 @@ class ETH3DStereo(StereoMatchingDataset): ...@@ -1169,7 +1176,7 @@ class ETH3DStereo(StereoMatchingDataset):
_has_built_in_disparity_mask = True _has_built_in_disparity_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
...@@ -1189,7 +1196,7 @@ class ETH3DStereo(StereoMatchingDataset): ...@@ -1189,7 +1196,7 @@ class ETH3DStereo(StereoMatchingDataset):
disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm") disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm")
self._disparities = self._scan_pairs(disparity_pattern, None) self._disparities = self._scan_pairs(disparity_pattern, None)
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
# test split has no disparity maps # test split has no disparity maps
if file_path is None: if file_path is None:
return None, None return None, None
...@@ -1201,7 +1208,7 @@ class ETH3DStereo(StereoMatchingDataset): ...@@ -1201,7 +1208,7 @@ class ETH3DStereo(StereoMatchingDataset):
valid_mask = np.asarray(valid_mask).astype(bool) valid_mask = np.asarray(valid_mask).astype(bool)
return disparity_map, valid_mask return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T2:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -1214,4 +1221,4 @@ class ETH3DStereo(StereoMatchingDataset): ...@@ -1214,4 +1221,4 @@ class ETH3DStereo(StereoMatchingDataset):
generate a valid mask. generate a valid mask.
Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
""" """
return super().__getitem__(index) return cast(T2, super().__getitem__(index))
...@@ -23,10 +23,10 @@ class CelebA(VisionDataset): ...@@ -23,10 +23,10 @@ class CelebA(VisionDataset):
or ``landmarks``. Can also be a list to output a tuple with all specified target types. or ``landmarks``. Can also be a list to output a tuple with all specified target types.
The targets represent: The targets represent:
- ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes - ``attr`` (Tensor shape=(40,) dtype=int): binary (0, 1) labels for attributes
- ``identity`` (int): label for each person (data points with the same identity are the same person) - ``identity`` (int): label for each person (data points with the same identity are the same person)
- ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height) - ``bbox`` (Tensor shape=(4,) dtype=int): bounding box (x, y, width, height)
- ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x, - ``landmarks`` (Tensor shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y) righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
Defaults to ``attr``. If empty, ``None`` will be returned as target. Defaults to ``attr``. If empty, ``None`` will be returned as target.
...@@ -41,7 +41,7 @@ class CelebA(VisionDataset): ...@@ -41,7 +41,7 @@ class CelebA(VisionDataset):
""" """
base_folder = "celeba" base_folder = "celeba"
# There currently does not appear to be a easy way to extract 7z in python (without introducing additional # There currently does not appear to be an easy way to extract 7z in python (without introducing additional
# dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available
# right now. # right now.
file_list = [ file_list = [
......
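A minimal usage sketch of the target types listed above (the root path and download flag are assumptions):
from torchvision import datasets
celeba = datasets.CelebA(root="data", split="train", target_type=["attr", "bbox"], download=True)
img, (attr, bbox) = celeba[0]   # attr: (40,) integer Tensor, bbox: (4,) integer Tensor (x, y, width, height)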