// ROIPool.h
#pragma once

#include "cpu/vision_cpu.h"

#ifdef WITH_CUDA
#include "cuda/vision_cuda.h"
#endif

/// Forward entry point for ROI pooling: picks a backend and forwards to it.
///
/// The backend is selected from `input` only. When `input` reports a CUDA
/// type the call goes to `ROIPool_forward_cuda`; otherwise it goes to
/// `ROIPool_forward_cpu`. All arguments are passed through unchanged and the
/// backend's result is returned as-is.
///
/// @param input          Tensor whose type decides the backend.
/// @param rois           Regions of interest, forwarded to the backend.
/// @param spatial_scale  Scale factor forwarded to the backend.
/// @param pooled_height  Pooled output height forwarded to the backend.
/// @param pooled_width   Pooled output width forwarded to the backend.
/// @return The pair of tensors produced by the selected backend.
///         NOTE(review): presumably (pooled output, argmax indices) —
///         confirm against the backend implementations.
/// @throws Raises via AT_ERROR("Not compiled with GPU support") when `input`
///         is a CUDA tensor but the build did not define WITH_CUDA.
std::tuple<at::Tensor, at::Tensor> ROIPool_forward(
    const at::Tensor& input,
    const at::Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width) {
  if (input.type().is_cuda()) {
#ifdef WITH_CUDA
    return ROIPool_forward_cuda(
        input, rois, spatial_scale, pooled_height, pooled_width);
#else
    // CUDA tensor in a CPU-only build: there is no backend to run it on.
    AT_ERROR("Not compiled with GPU support");
#endif
  }
  return ROIPool_forward_cpu(
      input, rois, spatial_scale, pooled_height, pooled_width);
}

/// Backward entry point for ROI pooling: picks a backend and forwards to it.
///
/// The backend is selected from `grad` only. A CUDA `grad` is routed to
/// `ROIPool_backward_cuda`; anything else is routed to
/// `ROIPool_backward_cpu`. Every argument is passed through untouched and the
/// backend's tensor is returned as-is.
///
/// NOTE(review): the scalar parameters here are `float`/`int`, while
/// `ROIPool_forward` in this header takes `double`/`int64_t`. Confirm whether
/// the mismatch is intentional before aligning them.
///
/// @param grad           Incoming gradient; its type decides the backend.
/// @param rois           Regions of interest, forwarded to the backend.
/// @param argmax         Forwarded to the backend (presumably the indices
///                       recorded by the forward pass — confirm).
/// @param spatial_scale  Scale factor forwarded to the backend.
/// @param pooled_height  Pooled height forwarded to the backend.
/// @param pooled_width   Pooled width forwarded to the backend.
/// @param batch_size     Forwarded to the backend.
/// @param channels       Forwarded to the backend.
/// @param height         Forwarded to the backend.
/// @param width          Forwarded to the backend.
/// @return The tensor produced by the selected backend.
/// @throws Raises via AT_ERROR("Not compiled with GPU support") when `grad`
///         is a CUDA tensor but the build did not define WITH_CUDA.
at::Tensor ROIPool_backward(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& argmax,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width,
    const int batch_size,
    const int channels,
    const int height,
    const int width) {
  const bool grad_on_cuda = grad.type().is_cuda();
  if (grad_on_cuda) {
#ifdef WITH_CUDA
    return ROIPool_backward_cuda(
        grad, rois, argmax,
        spatial_scale, pooled_height, pooled_width,
        batch_size, channels, height, width);
#else
    // CUDA gradient in a CPU-only build: there is no backend to run it on.
    AT_ERROR("Not compiled with GPU support");
#endif
  }
  return ROIPool_backward_cpu(
      grad, rois, argmax,
      spatial_scale, pooled_height, pooled_width,
      batch_size, channels, height, width);
}