Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
02920db2
Unverified
Commit
02920db2
authored
Dec 13, 2020
by
BigBigDream
Committed by
GitHub
Dec 13, 2020
Browse files
fix roi_align ci for parrots (#708)
* fix roi_align ci for parrots * fix lint
parent
b7136e39
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
572 additions
and
33 deletions
+572
-33
mmcv/ops/csrc/parrots/roi_align.cpp
mmcv/ops/csrc/parrots/roi_align.cpp
+94
-17
mmcv/ops/csrc/parrots/roi_align_cpu.cpp
mmcv/ops/csrc/parrots/roi_align_cpu.cpp
+430
-0
mmcv/ops/csrc/parrots/roi_align_cuda.cu
mmcv/ops/csrc/parrots/roi_align_cuda.cu
+14
-15
mmcv/ops/csrc/parrots_cpp_helper.hpp
mmcv/ops/csrc/parrots_cpp_helper.hpp
+29
-0
tests/test_ops/test_roi_align.py
tests/test_ops/test_roi_align.py
+5
-1
No files found.
mmcv/ops/csrc/parrots/roi_align.cpp
View file @
02920db2
// Copyright (c) 2018, SenseTime.
// Copyright (c) 2018, SenseTime.
#include "parrots_cpp_helper.hpp"
#include "parrots_cpp_helper.hpp"
void
ROIAlignForwardCUDAKernelLauncher
(
const
DArrayLite
input
,
void
ROIAlignForwardCPULauncher
(
DArrayLite
input
,
DArrayLite
rois
,
const
DArrayLite
rois
,
DArrayLite
output
,
DArrayLite
output
,
DArrayLite
argmax_y
,
DArrayLite
argmax_x
,
int
aligned_height
,
int
aligned_width
,
float
spatial_scale
,
int
sampling_ratio
,
int
pool_mode
,
bool
aligned
);
void
ROIAlignBackwardCPULauncher
(
DArrayLite
grad_output
,
DArrayLite
rois
,
DArrayLite
argmax_y
,
DArrayLite
argmax_x
,
DArrayLite
argmax_y
,
DArrayLite
argmax_x
,
int
aligned_height
,
int
aligned_width
,
DArrayLite
grad_input
,
int
aligned_height
,
float
spatial_scale
,
int
sampling_ratio
,
int
aligned_width
,
float
spatial_scale
,
int
pool_mode
,
bool
aligned
,
int
sampling_ratio
,
int
pool_mode
,
cudaStream_t
stream
);
bool
aligned
);
void
ROIAlignForwardCUDAKernelLauncher
(
DArrayLite
input
,
DArrayLite
rois
,
DArrayLite
output
,
DArrayLite
argmax_y
,
DArrayLite
argmax_x
,
int
aligned_height
,
int
aligned_width
,
float
spatial_scale
,
int
sampling_ratio
,
int
pool_mode
,
bool
aligned
,
cudaStream_t
stream
);
void
ROIAlignBackwardCUDAKernelLauncher
(
void
ROIAlignBackwardCUDAKernelLauncher
(
const
DArrayLite
grad_output
,
const
DArrayLite
rois
,
DArrayLite
grad_output
,
DArrayLite
rois
,
DArrayLite
argmax_y
,
const
DArrayLite
argmax_y
,
const
DArrayLite
argmax_x
,
DArrayLite
grad_input
,
DArrayLite
argmax_x
,
DArrayLite
grad_input
,
int
aligned_height
,
int
aligned_height
,
int
aligned_width
,
float
spatial_scale
,
int
aligned_width
,
float
spatial_scale
,
int
sampling_ratio
,
int
pool_mode
,
int
sampling_ratio
,
int
pool_mode
,
bool
aligned
,
cudaStream_t
stream
);
bool
aligned
,
cudaStream_t
stream
);
// Host-side entry point for the roi_align forward operator.
// Pulls the six operator attributes out of `attr`, then hands ins[0] (input),
// ins[1] (rois) and outs[0..2] (output, argmax_y, argmax_x) to
// ROIAlignForwardCPULauncher. `ctx` is not referenced in this function.
void roi_align_forward_cpu(HostContext& ctx, const SSElement& attr,
                           const OperatorBase::in_list_t& ins,
                           OperatorBase::out_list_t& outs) {
  int out_h, out_w, samples, mode;
  float scale;
  bool is_aligned;
  SSAttrs(attr)
      .get<int>("aligned_height", out_h)
      .get<int>("aligned_width", out_w)
      .get<float>("spatial_scale", scale)
      .get<int>("sampling_ratio", samples)
      .get<int>("pool_mode", mode)
      .get<bool>("aligned", is_aligned)
      .done();

  ROIAlignForwardCPULauncher(/*input=*/ins[0], /*rois=*/ins[1],
                             /*output=*/outs[0], /*argmax_y=*/outs[1],
                             /*argmax_x=*/outs[2], out_h, out_w, scale,
                             samples, mode, is_aligned);
}
// Host-side entry point for the roi_align backward operator.
// Pulls the six operator attributes out of `attr`, then hands ins[0..3]
// (grad_output, rois, argmax_y, argmax_x) and outs[0] (grad_input) to
// ROIAlignBackwardCPULauncher. `ctx` is not referenced in this function.
void roi_align_backward_cpu(HostContext& ctx, const SSElement& attr,
                            const OperatorBase::in_list_t& ins,
                            OperatorBase::out_list_t& outs) {
  int out_h, out_w, samples, mode;
  float scale;
  bool is_aligned;
  SSAttrs(attr)
      .get<int>("aligned_height", out_h)
      .get<int>("aligned_width", out_w)
      .get<float>("spatial_scale", scale)
      .get<int>("sampling_ratio", samples)
      .get<int>("pool_mode", mode)
      .get<bool>("aligned", is_aligned)
      .done();

  ROIAlignBackwardCPULauncher(/*grad_output=*/ins[0], /*rois=*/ins[1],
                              /*argmax_y=*/ins[2], /*argmax_x=*/ins[3],
                              /*grad_input=*/outs[0], out_h, out_w, scale,
                              samples, mode, is_aligned);
}
void
roi_align_forward_cuda
(
CudaContext
&
ctx
,
const
SSElement
&
attr
,
void
roi_align_forward_cuda
(
CudaContext
&
ctx
,
const
SSElement
&
attr
,
const
OperatorBase
::
in_list_t
&
ins
,
const
OperatorBase
::
in_list_t
&
ins
,
...
@@ -33,8 +104,8 @@ void roi_align_forward_cuda(CudaContext& ctx, const SSElement& attr,
...
@@ -33,8 +104,8 @@ void roi_align_forward_cuda(CudaContext& ctx, const SSElement& attr,
.
get
<
bool
>
(
"aligned"
,
aligned
)
.
get
<
bool
>
(
"aligned"
,
aligned
)
.
done
();
.
done
();
const
auto
&
input
=
ins
[
0
];
auto
&
input
=
ins
[
0
];
const
auto
&
rois
=
ins
[
1
];
auto
&
rois
=
ins
[
1
];
auto
&
output
=
outs
[
0
];
auto
&
output
=
outs
[
0
];
auto
&
argmax_y
=
outs
[
1
];
auto
&
argmax_y
=
outs
[
1
];
auto
&
argmax_x
=
outs
[
2
];
auto
&
argmax_x
=
outs
[
2
];
...
@@ -63,10 +134,10 @@ void roi_align_backward_cuda(CudaContext& ctx, const SSElement& attr,
...
@@ -63,10 +134,10 @@ void roi_align_backward_cuda(CudaContext& ctx, const SSElement& attr,
.
get
<
bool
>
(
"aligned"
,
aligned
)
.
get
<
bool
>
(
"aligned"
,
aligned
)
.
done
();
.
done
();
const
auto
&
grad_output
=
ins
[
0
];
auto
&
grad_output
=
ins
[
0
];
const
auto
&
rois
=
ins
[
1
];
auto
&
rois
=
ins
[
1
];
const
auto
&
argmax_y
=
ins
[
2
];
auto
&
argmax_y
=
ins
[
2
];
const
auto
&
argmax_x
=
ins
[
3
];
auto
&
argmax_x
=
ins
[
3
];
auto
&
grad_input
=
outs
[
0
];
auto
&
grad_input
=
outs
[
0
];
cudaStream_t
stream
=
getStreamNative
<
CudaDevice
>
(
ctx
.
getStream
());
cudaStream_t
stream
=
getStreamNative
<
CudaDevice
>
(
ctx
.
getStream
());
...
@@ -84,7 +155,10 @@ PARROTS_EXTENSION_REGISTER(roi_align_forward)
...
@@ -84,7 +155,10 @@ PARROTS_EXTENSION_REGISTER(roi_align_forward)
.
attr
(
"aligned"
)
.
attr
(
"aligned"
)
.
input
(
2
)
.
input
(
2
)
.
output
(
3
)
.
output
(
3
)
.
apply
(
roi_align_forward_cpu
)
#ifdef PARROTS_USE_CUDA
.
apply
(
roi_align_forward_cuda
)
.
apply
(
roi_align_forward_cuda
)
#endif
.
done
();
.
done
();
PARROTS_EXTENSION_REGISTER
(
roi_align_backward
)
PARROTS_EXTENSION_REGISTER
(
roi_align_backward
)
...
@@ -96,5 +170,8 @@ PARROTS_EXTENSION_REGISTER(roi_align_backward)
...
@@ -96,5 +170,8 @@ PARROTS_EXTENSION_REGISTER(roi_align_backward)
.
attr
(
"aligned"
)
.
attr
(
"aligned"
)
.
input
(
4
)
.
input
(
4
)
.
output
(
1
)
.
output
(
1
)
.
apply
(
roi_align_backward_cpu
)
#ifdef PARROTS_USE_CUDA
.
apply
(
roi_align_backward_cuda
)
.
apply
(
roi_align_backward_cuda
)
#endif
.
done
();
.
done
();
mmcv/ops/csrc/parrots/roi_align_cpu.cpp
0 → 100644
View file @
02920db2
// Modified from
// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/ROIAlign
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include <iostream>
#include "parrots_cpp_helper.hpp"
// implementation taken from Caffe2
// One cached bilinear-interpolation sample: four flattened positions and the
// weight that multiplies the value read at each of them.
template <typename T>
struct PreCalc {
  int pos1, pos2, pos3, pos4;  // flattened offsets of the four neighbours
  T w1, w2, w3, w4;            // weights paired with pos1..pos4
};
template
<
typename
T
>
void
pre_calc_for_bilinear_interpolate
(
const
int
height
,
const
int
width
,
const
int
pooled_height
,
const
int
pooled_width
,
const
int
iy_upper
,
const
int
ix_upper
,
T
roi_start_h
,
T
roi_start_w
,
T
bin_size_h
,
T
bin_size_w
,
int
roi_bin_grid_h
,
int
roi_bin_grid_w
,
std
::
vector
<
PreCalc
<
T
>>&
pre_calc
)
{
int
pre_calc_index
=
0
;
for
(
int
ph
=
0
;
ph
<
pooled_height
;
ph
++
)
{
for
(
int
pw
=
0
;
pw
<
pooled_width
;
pw
++
)
{
for
(
int
iy
=
0
;
iy
<
iy_upper
;
iy
++
)
{
const
T
yy
=
roi_start_h
+
ph
*
bin_size_h
+
static_cast
<
T
>
(
iy
+
.5
f
)
*
bin_size_h
/
static_cast
<
T
>
(
roi_bin_grid_h
);
// e.g., 0.5, 1.5
for
(
int
ix
=
0
;
ix
<
ix_upper
;
ix
++
)
{
const
T
xx
=
roi_start_w
+
pw
*
bin_size_w
+
static_cast
<
T
>
(
ix
+
.5
f
)
*
bin_size_w
/
static_cast
<
T
>
(
roi_bin_grid_w
);
T
x
=
xx
;
T
y
=
yy
;
// deal with: inverse elements are out of feature map boundary
if
(
y
<
-
1.0
||
y
>
height
||
x
<
-
1.0
||
x
>
width
)
{
// empty
PreCalc
<
T
>
pc
;
pc
.
pos1
=
0
;
pc
.
pos2
=
0
;
pc
.
pos3
=
0
;
pc
.
pos4
=
0
;
pc
.
w1
=
0
;
pc
.
w2
=
0
;
pc
.
w3
=
0
;
pc
.
w4
=
0
;
pre_calc
[
pre_calc_index
]
=
pc
;
pre_calc_index
+=
1
;
continue
;
}
if
(
y
<=
0
)
{
y
=
0
;
}
if
(
x
<=
0
)
{
x
=
0
;
}
int
y_low
=
(
int
)
y
;
int
x_low
=
(
int
)
x
;
int
y_high
;
int
x_high
;
if
(
y_low
>=
height
-
1
)
{
y_high
=
y_low
=
height
-
1
;
y
=
(
T
)
y_low
;
}
else
{
y_high
=
y_low
+
1
;
}
if
(
x_low
>=
width
-
1
)
{
x_high
=
x_low
=
width
-
1
;
x
=
(
T
)
x_low
;
}
else
{
x_high
=
x_low
+
1
;
}
T
ly
=
y
-
y_low
;
T
lx
=
x
-
x_low
;
T
hy
=
1.
-
ly
,
hx
=
1.
-
lx
;
T
w1
=
hy
*
hx
,
w2
=
hy
*
lx
,
w3
=
ly
*
hx
,
w4
=
ly
*
lx
;
// save weights and indices
PreCalc
<
T
>
pc
;
pc
.
pos1
=
y_low
*
width
+
x_low
;
pc
.
pos2
=
y_low
*
width
+
x_high
;
pc
.
pos3
=
y_high
*
width
+
x_low
;
pc
.
pos4
=
y_high
*
width
+
x_high
;
pc
.
w1
=
w1
;
pc
.
w2
=
w2
;
pc
.
w3
=
w3
;
pc
.
w4
=
w4
;
pre_calc
[
pre_calc_index
]
=
pc
;
pre_calc_index
+=
1
;
}
}
}
}
}
// CPU forward pass of RoIAlign: pools bilinear samples of `input` inside each
// ROI bin, by max (pool_mode == 0) or by average (pool_mode == 1).
//
// Parameters
//   nthreads        total number of pooled output elements; the ROI count is
//                   derived as nthreads / channels / pooled_width /
//                   pooled_height.
//   input           feature map, indexed as a contiguous
//                   (batch, channels, height, width) array.
//   rois            5 values per ROI: batch index, then start_w, start_h,
//                   end_w, end_h (multiplied by spatial_scale here).
//   output          pooled values, indexed as (roi, channel, ph, pw).
//   argmax_y/x      written only when pool_mode == 0: coordinates of the
//                   sample that produced the maximum (-1 if none did).
//   pooled_height/width  output bin grid size per ROI.
//   spatial_scale   factor applied to the ROI coordinates.
//   sampling_ratio  samples per bin along each axis; when <= 0 it is derived
//                   per ROI as ceil(roi_size / pooled_size).
//   aligned         when true, shifts ROI coordinates by -0.5 and requires a
//                   non-negative ROI size; when false, ROI width/height are
//                   forced to at least 1.
//
// Any other pool_mode leaves `output` untouched.
template <typename T>
void ROIAlignForward(const int nthreads, const T* input, const T* rois,
                     T* output, T* argmax_y, T* argmax_x,
                     const int pooled_height, const int pooled_width,
                     const T spatial_scale, const int sampling_ratio,
                     const int pool_mode,  // 0 - max pool, 1 - avg pool
                     const bool aligned, const int channels, const int height,
                     const int width) {
  // Nothing to pool for empty shapes; this also keeps the ROI-count division
  // below from dividing by zero.
  if (nthreads <= 0 || channels <= 0 || pooled_width <= 0 ||
      pooled_height <= 0) {
    return;
  }

  int n_rois = nthreads / channels / pooled_width / pooled_height;
  // (n, c, ph, pw) is an element in the pooled output
  // can be parallelized using omp
  // #pragma omp parallel for num_threads(32)
  for (int n = 0; n < n_rois; n++) {
    int index_n = n * channels * pooled_width * pooled_height;

    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];

    // Do not use rounding; this implementation detail is critical
    T offset = aligned ? (T)0.5 : (T)0.0;
    T roi_start_w = offset_rois[1] * spatial_scale - offset;
    T roi_start_h = offset_rois[2] * spatial_scale - offset;
    T roi_end_w = offset_rois[3] * spatial_scale - offset;
    T roi_end_h = offset_rois[4] * spatial_scale - offset;

    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    if (aligned) {
      PARROTS_CHECKARGS(roi_width >= 0 && roi_height >= 0)
          << "ROIs in ROIAlign cannot have negative size!";
    } else {  // for backward-compatibility only
      roi_width = std::max(roi_width, (T)1.);
      roi_height = std::max(roi_height, (T)1.);
    }

    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sampling_ratio > 0)
                             ? sampling_ratio
                             : ceil(roi_height / pooled_height);  // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);

    // When the grid is empty, output zeros == 0/1, instead of NaN.
    const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1);  // e.g. = 4

    // we want to precalculate indices and weights shared by all channels,
    // this is the key point of optimization
    std::vector<PreCalc<T>> pre_calc(roi_bin_grid_h * roi_bin_grid_w *
                                     pooled_width * pooled_height);
    pre_calc_for_bilinear_interpolate(
        height, width, pooled_height, pooled_width, roi_bin_grid_h,
        roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, bin_size_w,
        roi_bin_grid_h, roi_bin_grid_w, pre_calc);

    for (int c = 0; c < channels; c++) {
      int index_n_c = index_n + c * pooled_width * pooled_height;
      const T* offset_input =
          input + (roi_batch_ind * channels + c) * height * width;
      // Replays the traversal order used to fill pre_calc.
      int pre_calc_index = 0;

      for (int ph = 0; ph < pooled_height; ph++) {
        for (int pw = 0; pw < pooled_width; pw++) {
          int index = index_n_c + ph * pooled_width + pw;

          T output_val = 0.;
          // NOTE(review): -10000 is the "no sample seen" floor; values below
          // it can never win the max. Left unchanged on purpose.
          T maxval = -10000;
          T maxidx_y = -1.f, maxidx_x = -1.f;
          for (int iy = 0; iy < roi_bin_grid_h; iy++) {
            // Sample coordinates are recomputed only to record the argmax.
            const T y = roi_start_h + ph * bin_size_h +
                        static_cast<T>(iy + .5f) * bin_size_h /
                            static_cast<T>(roi_bin_grid_h);
            for (int ix = 0; ix < roi_bin_grid_w; ix++) {
              const T x = roi_start_w + pw * bin_size_w +
                          static_cast<T>(ix + .5f) * bin_size_w /
                              static_cast<T>(roi_bin_grid_w);
              PreCalc<T> pc = pre_calc[pre_calc_index];
              T val = pc.w1 * offset_input[pc.pos1] +
                      pc.w2 * offset_input[pc.pos2] +
                      pc.w3 * offset_input[pc.pos3] +
                      pc.w4 * offset_input[pc.pos4];
              if (val > maxval) {
                maxval = val;
                maxidx_y = y;
                maxidx_x = x;
              }
              output_val += val;
              pre_calc_index += 1;
            }
          }
          if (pool_mode == 0) {
            // We do max pooling inside a bin
            output[index] = maxval;
            argmax_y[index] = maxidx_y;
            argmax_x[index] = maxidx_x;
          } else if (pool_mode == 1) {
            // We do average (integral) pooling inside a bin
            output[index] = output_val / count;
          }  // if
        }    // for pw
      }      // for ph
    }        // for c
  }          // for n
}
// Computes the four bilinear weights (w1..w4) and the neighbour indices
// (x_low, x_high, y_low, y_high) for the point (y, x) on a height x width
// plane.
//
// A point more than one pixel outside the plane yields zero weights and all
// indices set to -1. Otherwise negative coordinates are clamped to 0, and a
// point on or past the last row/column collapses low and high onto that
// row/column. `index` is not used (kept for debugging).
template <typename T>
void bilinear_interpolate_gradient(const int height, const int width, T y, T x,
                                   T& w1, T& w2, T& w3, T& w4, int& x_low,
                                   int& x_high, int& y_low, int& y_high,
                                   const int index /* index for debug only*/) {
  // Guard: sample falls outside the feature map.
  const bool outside = y < -1.0 || y > height || x < -1.0 || x > width;
  if (outside) {
    w4 = 0.;
    w3 = w4;
    w2 = w3;
    w1 = w2;
    y_high = -1;
    y_low = y_high;
    x_high = y_low;
    x_low = x_high;
    return;
  }

  if (y <= 0) {
    y = 0;
  }
  if (x <= 0) {
    x = 0;
  }

  y_low = (int)y;
  x_low = (int)x;

  const int last_row = height - 1;
  if (y_low < last_row) {
    y_high = y_low + 1;
  } else {
    y_low = last_row;
    y_high = y_low;
    y = (T)y_low;
  }

  const int last_col = width - 1;
  if (x_low < last_col) {
    x_high = x_low + 1;
  } else {
    x_low = last_col;
    x_high = x_low;
    x = (T)x_low;
  }

  // Fractional distances to the low neighbours and their complements.
  const T ly = y - y_low;
  const T lx = x - x_low;
  const T hy = 1. - ly;
  const T hx = 1. - lx;

  // The forward pass combines the neighbours as
  //   val = w1 * in[y_low][x_low] + w2 * in[y_low][x_high]
  //       + w3 * in[y_high][x_low] + w4 * in[y_high][x_high]
  w1 = hy * hx;
  w2 = hy * lx;
  w3 = ly * hx;
  w4 = ly * lx;
}
// Accumulates `val` into the element that `address` points to.
// Plain (non-atomic) read-modify-write.
template <class T>
inline void add(T* address, const T& val) {
  address[0] += val;
}
// CPU backward pass of RoIAlign: scatters each pooled-output gradient back to
// the input positions that contributed to it.
//
// For every flat index in [0, nthreads) the (n, c, ph, pw) output coordinate is
// recovered, the ROI geometry is rebuilt from `rois` (5 values per ROI), and
// the gradient read from `grad_output` (addressed through the n/c/h/w strides)
// is distributed into `grad_input`, which is indexed as a contiguous
// (batch, channels, height, width) array.
//   pool_mode == 0: gradient goes to the four neighbours of the recorded
//                   argmax sample (skipped when argmax_y is -1).
//   pool_mode == 1: gradient is split evenly over every sample of the bin.
// Any other pool_mode contributes nothing. `grad_input` is accumulated into,
// not cleared, by this function.
template <typename T>
void ROIAlignBackward(const int nthreads, const T* grad_output, const T* rois,
                      const T* argmax_y, const T* argmax_x, T* grad_input,
                      const int pooled_height, const int pooled_width,
                      const T spatial_scale, const int sampling_ratio,
                      const int pool_mode,  // 0 - max pool, 1 - avg pool
                      const bool aligned, const int channels, const int height,
                      const int width, const int n_stride, const int c_stride,
                      const int h_stride, const int w_stride) {
  for (int index = 0; index < nthreads; ++index) {
    // Recover (n, c, ph, pw) of the pooled output element.
    const int pw = index % pooled_width;
    const int ph = (index / pooled_width) % pooled_height;
    const int c = (index / pooled_width / pooled_height) % channels;
    const int n = index / pooled_width / pooled_height / channels;

    const T* roi = rois + n * 5;
    const int batch_idx = roi[0];

    // Do not use rounding; this implementation detail is critical
    const T offset = aligned ? (T)0.5 : (T)0.0;
    const T roi_start_w = roi[1] * spatial_scale - offset;
    const T roi_start_h = roi[2] * spatial_scale - offset;
    const T roi_end_w = roi[3] * spatial_scale - offset;
    const T roi_end_h = roi[4] * spatial_scale - offset;

    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    if (!aligned) {
      // for backward-compatibility only
      roi_width = std::max(roi_width, (T)1.);
      roi_height = std::max(roi_height, (T)1.);
    } else {
      PARROTS_CHECKARGS(roi_width >= 0 && roi_height >= 0)
          << "ROIs in ROIAlign do not have non-negative size!";
    }

    const T bin_size_h =
        static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    const T bin_size_w =
        static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // Plane of grad_input that this (batch, channel) pair writes to.
    T* grad_plane = grad_input + ((batch_idx * channels + c) * height * width);

    // Gradient flowing into this output bin, addressed via the strides.
    const T* grad_out_nc = grad_output + (n * n_stride + c * c_stride);
    const T bin_grad = grad_out_nc[ph * h_stride + pw * w_stride];

    // Adds the four neighbour contributions, unless the sample was flagged
    // out of range (negative indices).
    // atomic add is not needed for now since it is single threaded
    auto scatter = [&](int x_low, int x_high, int y_low, int y_high, T g1,
                       T g2, T g3, T g4) {
      if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
        add(grad_plane + y_low * width + x_low, static_cast<T>(g1));
        add(grad_plane + y_low * width + x_high, static_cast<T>(g2));
        add(grad_plane + y_high * width + x_low, static_cast<T>(g3));
        add(grad_plane + y_high * width + x_high, static_cast<T>(g4));
      }
    };

    switch (pool_mode) {
      case 0: {
        // Max pooling: only the recorded argmax sample receives gradient.
        T y = argmax_y[index], x = argmax_x[index];
        if (y != -1.f) {
          T w1, w2, w3, w4;
          int x_low, x_high, y_low, y_high;
          bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
                                        x_low, x_high, y_low, y_high, index);

          T g1 = bin_grad * w1;
          T g2 = bin_grad * w2;
          T g3 = bin_grad * w3;
          T g4 = bin_grad * w4;
          scatter(x_low, x_high, y_low, y_high, g1, g2, g3, g4);
        }
        break;
      }
      case 1: {
        // Average (integral) pooling: every sample of the bin gets an equal
        // share. We use roi_bin_grid to sample the grid and mimic integral.
        int roi_bin_grid_h = (sampling_ratio > 0)
                                 ? sampling_ratio
                                 : ceil(roi_height / pooled_height);  // e.g., = 2
        int roi_bin_grid_w = (sampling_ratio > 0)
                                 ? sampling_ratio
                                 : ceil(roi_width / pooled_width);

        const T count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4

        for (int iy = 0; iy < roi_bin_grid_h; ++iy) {
          const T y = roi_start_h + ph * bin_size_h +
                      static_cast<T>(iy + .5f) * bin_size_h /
                          static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
          for (int ix = 0; ix < roi_bin_grid_w; ++ix) {
            const T x = roi_start_w + pw * bin_size_w +
                        static_cast<T>(ix + .5f) * bin_size_w /
                            static_cast<T>(roi_bin_grid_w);

            T w1, w2, w3, w4;
            int x_low, x_high, y_low, y_high;
            bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
                                          x_low, x_high, y_low, y_high, index);

            T g1 = bin_grad * w1 / count;
            T g2 = bin_grad * w2 / count;
            T g3 = bin_grad * w3 / count;
            T g4 = bin_grad * w4 / count;
            scatter(x_low, x_high, y_low, y_high, g1, g2, g3, g4);
          }  // ix
        }    // iy
        break;
      }
      default:
        break;
    }  // mode
  }    // for
}  // ROIAlignBackward
// Dispatches the CPU RoIAlign forward pass on the element type of `input`.
// The total element count of `output` and dims 1..3 of `input` (used as
// channels, height, width) are read once, then ROIAlignForward<scalar_t> runs
// on the raw pointers of the five arrays.
void ROIAlignForwardCPULauncher(DArrayLite input, DArrayLite rois,
                                DArrayLite output, DArrayLite argmax_y,
                                DArrayLite argmax_x, int aligned_height,
                                int aligned_width, float spatial_scale,
                                int sampling_ratio, int pool_mode,
                                bool aligned) {
  const int num_outputs = output.size();
  const int num_channels = input.dim(1);
  const int in_height = input.dim(2);
  const int in_width = input.dim(3);

  PARROTS_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.elemType().prim(), ([&] {
        ROIAlignForward<scalar_t>(
            num_outputs, input.ptr<scalar_t>(), rois.ptr<scalar_t>(),
            output.ptr<scalar_t>(), argmax_y.ptr<scalar_t>(),
            argmax_x.ptr<scalar_t>(), aligned_height, aligned_width,
            static_cast<scalar_t>(spatial_scale), sampling_ratio, pool_mode,
            aligned, num_channels, in_height, in_width);
      }));
}
// Dispatches the CPU RoIAlign backward pass on the element type of
// `grad_output`. Dims 1..3 of `grad_input` are used as channels, height and
// width; the four strides of `grad_output` are forwarded so the kernel can
// index the incoming gradient correctly.
void ROIAlignBackwardCPULauncher(DArrayLite grad_output, DArrayLite rois,
                                 DArrayLite argmax_y, DArrayLite argmax_x,
                                 DArrayLite grad_input, int aligned_height,
                                 int aligned_width, float spatial_scale,
                                 int sampling_ratio, int pool_mode,
                                 bool aligned) {
  const int num_outputs = grad_output.size();
  const int num_channels = grad_input.dim(1);
  const int in_height = grad_input.dim(2);
  const int in_width = grad_input.dim(3);

  // get stride values to ensure indexing into gradients is correct.
  const int stride_n = grad_output.stride(0);
  const int stride_c = grad_output.stride(1);
  const int stride_h = grad_output.stride(2);
  const int stride_w = grad_output.stride(3);

  PARROTS_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_output.elemType().prim(), ([&] {
        ROIAlignBackward<scalar_t>(
            num_outputs, grad_output.ptr<scalar_t>(), rois.ptr<scalar_t>(),
            argmax_y.ptr<scalar_t>(), argmax_x.ptr<scalar_t>(),
            grad_input.ptr<scalar_t>(), aligned_height, aligned_width,
            static_cast<scalar_t>(spatial_scale), sampling_ratio, pool_mode,
            aligned, num_channels, in_height, in_width, stride_n, stride_c,
            stride_h, stride_w);
      }));
}
mmcv/ops/csrc/parrots/roi_align_cuda.cu
View file @
02920db2
#include "parrots_cuda_helper.hpp"
#include "parrots_cuda_helper.hpp"
#include "roi_align_cuda_kernel.cuh"
#include "roi_align_cuda_kernel.cuh"
void
ROIAlignForwardCUDAKernelLauncher
(
const
DArrayLite
input
,
void
ROIAlignForwardCUDAKernelLauncher
(
DArrayLite
input
,
DArrayLite
rois
,
const
DArrayLite
rois
,
DArrayLite
output
,
DArrayLite
output
,
DArrayLite
argmax_y
,
DArrayLite
argmax_y
,
DArrayLite
argmax_x
,
DArrayLite
argmax_x
,
int
aligned_height
,
int
aligned_height
,
int
aligned_width
,
int
aligned_width
,
float
spatial_scale
,
float
spatial_scale
,
int
sampling_ratio
,
int
sampling_ratio
,
int
pool_mode
,
int
pool_mode
,
bool
aligned
,
bool
aligned
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
int
output_size
=
output
.
size
();
int
output_size
=
output
.
size
();
int
channels
=
input
.
dim
(
1
);
int
channels
=
input
.
dim
(
1
);
int
height
=
input
.
dim
(
2
);
int
height
=
input
.
dim
(
2
);
...
@@ -20,18 +19,18 @@ void ROIAlignForwardCUDAKernelLauncher(const DArrayLite input,
...
@@ -20,18 +19,18 @@ void ROIAlignForwardCUDAKernelLauncher(const DArrayLite input,
output_size
,
input
.
ptr
<
scalar_t
>
(),
rois
.
ptr
<
scalar_t
>
(),
output_size
,
input
.
ptr
<
scalar_t
>
(),
rois
.
ptr
<
scalar_t
>
(),
output
.
ptr
<
scalar_t
>
(),
argmax_y
.
ptr
<
scalar_t
>
(),
output
.
ptr
<
scalar_t
>
(),
argmax_y
.
ptr
<
scalar_t
>
(),
argmax_x
.
ptr
<
scalar_t
>
(),
aligned_height
,
aligned_width
,
argmax_x
.
ptr
<
scalar_t
>
(),
aligned_height
,
aligned_width
,
spatial_scale
,
sampling_ratio
,
pool_mode
,
aligned
,
channels
,
static_cast
<
scalar_t
>
(
spatial_scale
)
,
sampling_ratio
,
pool_mode
,
height
,
width
);
aligned
,
channels
,
height
,
width
);
}));
}));
PARROTS_CUDA_CHECK
(
cudaGetLastError
());
PARROTS_CUDA_CHECK
(
cudaGetLastError
());
}
}
void
ROIAlignBackwardCUDAKernelLauncher
(
void
ROIAlignBackwardCUDAKernelLauncher
(
const
DArrayLite
grad_output
,
const
DArrayLite
rois
,
DArrayLite
grad_output
,
DArrayLite
rois
,
DArrayLite
argmax_y
,
const
DArrayLite
argmax_y
,
const
DArrayLite
argmax_x
,
DArrayLite
grad_input
,
DArrayLite
argmax_x
,
DArrayLite
grad_input
,
int
aligned_height
,
int
aligned_height
,
int
aligned_width
,
float
spatial_scale
,
int
aligned_width
,
float
spatial_scale
,
int
sampling_ratio
,
int
pool_mode
,
int
sampling_ratio
,
int
pool_mode
,
bool
aligned
,
cudaStream_t
stream
)
{
bool
aligned
,
cudaStream_t
stream
)
{
int
output_size
=
grad_output
.
size
();
int
output_size
=
grad_output
.
size
();
int
channels
=
grad_input
.
dim
(
1
);
int
channels
=
grad_input
.
dim
(
1
);
int
height
=
grad_input
.
dim
(
2
);
int
height
=
grad_input
.
dim
(
2
);
...
@@ -44,8 +43,8 @@ void ROIAlignBackwardCUDAKernelLauncher(
...
@@ -44,8 +43,8 @@ void ROIAlignBackwardCUDAKernelLauncher(
output_size
,
grad_output
.
ptr
<
scalar_t
>
(),
rois
.
ptr
<
scalar_t
>
(),
output_size
,
grad_output
.
ptr
<
scalar_t
>
(),
rois
.
ptr
<
scalar_t
>
(),
argmax_y
.
ptr
<
scalar_t
>
(),
argmax_x
.
ptr
<
scalar_t
>
(),
argmax_y
.
ptr
<
scalar_t
>
(),
argmax_x
.
ptr
<
scalar_t
>
(),
grad_input
.
ptr
<
scalar_t
>
(),
aligned_height
,
aligned_width
,
grad_input
.
ptr
<
scalar_t
>
(),
aligned_height
,
aligned_width
,
spatial_scale
,
sampling_ratio
,
pool_mode
,
aligned
,
channels
,
static_cast
<
scalar_t
>
(
spatial_scale
)
,
sampling_ratio
,
pool_mode
,
height
,
width
);
aligned
,
channels
,
height
,
width
);
}));
}));
PARROTS_CUDA_CHECK
(
cudaGetLastError
());
PARROTS_CUDA_CHECK
(
cudaGetLastError
());
...
...
mmcv/ops/csrc/parrots_cpp_helper.hpp
View file @
02920db2
...
@@ -8,4 +8,33 @@
...
@@ -8,4 +8,33 @@
using
namespace
parrots
;
using
namespace
parrots
;
// Emits one `case prim_type:` arm for use inside a switch statement: aliases
// `scalar_t` to `type`, then returns the result of calling the trailing
// macro argument(s) with no arguments.
#define PARROTS_PRIVATE_CASE_TYPE(prim_type, type, ...) \
case prim_type: { \
using scalar_t = type; \
return __VA_ARGS__(); \
}
// Expands to an immediately-invoked lambda that switches on TYPE and calls the
// trailing callable with `scalar_t` aliased to `double` (Prim::Float64) or
// `float` (Prim::Float32). Any other value reaches PARROTS_NOTSUPPORTED
// (defined elsewhere; presumably reports an unsupported type -- confirm).
#define PARROTS_DISPATCH_FLOATING_TYPES(TYPE, ...) \
[&] { \
const auto& the_type = TYPE; \
switch (the_type) { \
PARROTS_PRIVATE_CASE_TYPE(Prim::Float64, double, __VA_ARGS__) \
PARROTS_PRIVATE_CASE_TYPE(Prim::Float32, float, __VA_ARGS__) \
default: \
PARROTS_NOTSUPPORTED; \
} \
}()
// Same dispatch as PARROTS_DISPATCH_FLOATING_TYPES with one extra arm:
// Prim::Float16 aliases `scalar_t` to `float16`. Any value other than the
// three listed reaches PARROTS_NOTSUPPORTED (defined elsewhere; presumably
// reports an unsupported type -- confirm).
#define PARROTS_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, ...) \
[&] { \
const auto& the_type = TYPE; \
switch (the_type) { \
PARROTS_PRIVATE_CASE_TYPE(Prim::Float64, double, __VA_ARGS__) \
PARROTS_PRIVATE_CASE_TYPE(Prim::Float32, float, __VA_ARGS__) \
PARROTS_PRIVATE_CASE_TYPE(Prim::Float16, float16, __VA_ARGS__) \
default: \
PARROTS_NOTSUPPORTED; \
} \
}()
#endif // PARROTS_CPP_HELPER
#endif // PARROTS_CPP_HELPER
tests/test_ops/test_roi_align.py
View file @
02920db2
...
@@ -56,6 +56,10 @@ def _test_roialign_gradcheck(device, dtype):
...
@@ -56,6 +56,10 @@ def _test_roialign_gradcheck(device, dtype):
froipool
=
RoIAlign
((
pool_h
,
pool_w
),
spatial_scale
,
sampling_ratio
)
froipool
=
RoIAlign
((
pool_h
,
pool_w
),
spatial_scale
,
sampling_ratio
)
if
torch
.
__version__
==
'parrots'
:
gradcheck
(
froipool
,
(
x
,
rois
),
no_grads
=
[
rois
],
delta
=
1e-5
,
pt_atol
=
1e-5
)
else
:
gradcheck
(
froipool
,
(
x
,
rois
),
eps
=
1e-5
,
atol
=
1e-5
)
gradcheck
(
froipool
,
(
x
,
rois
),
eps
=
1e-5
,
atol
=
1e-5
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment