ModelZoo / SOLOv2-pytorch · Commit 3f412c39

support half tensors

authored Dec 09, 2018 by Kai Chen
parent 826a5613

Showing 6 changed files with 19 additions and 45 deletions (+19, -45)
mmdet/ops/roi_align/functions/roi_align.py   +6  -6
mmdet/ops/roi_align/src/roi_align_cuda.cpp   +1  -1
mmdet/ops/roi_align/src/roi_align_kernel.cu  +3  -16
mmdet/ops/roi_pool/functions/roi_pool.py     +5  -6
mmdet/ops/roi_pool/src/roi_pool_cuda.cpp     +1  -1
mmdet/ops/roi_pool/src/roi_pool_kernel.cu    +3  -15
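Editorial note: taken together, the six diffs below let the custom RoIAlign/RoIPool CUDA ops consume half-precision (fp16) feature maps: the Python wrappers allocate gradient buffers that inherit the input dtype, the C++ sources move to the PyTorch 1.0 extension header, and the kernel launchers tidy their dispatch and error handling. A hypothetical usage sketch; the Function signature and package layout are assumed from the surrounding mmdetection-style code, which this page shows only in part:

    import torch
    from mmdet.ops.roi_align.functions.roi_align import RoIAlignFunction

    # fp16 feature map and rois on the GPU; roi rows are (batch_idx, x1, y1, x2, y2)
    feats = torch.randn(2, 256, 64, 64, device='cuda').half()
    rois = torch.tensor([[0., 4., 4., 32., 32.]], device='cuda').half()

    # out_size=7, spatial_scale=1/8, sample_num=2 are illustrative values
    out = RoIAlignFunction.apply(feats, rois, 7, 1.0 / 8, 2)
    assert out.dtype == torch.half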
mmdet/ops/roi_align/functions/roi_align.py

-from torch.autograd import Function, Variable
+from torch.autograd import Function
 
 from .. import roi_align_cuda

@@ -49,11 +49,11 @@ class RoIAlignFunction(Function):
         grad_input = grad_rois = None
         if ctx.needs_input_grad[0]:
-            grad_input = Variable(
-                rois.new(batch_size, num_channels, data_height,
-                         data_width).zero_())
-            roi_align_cuda.backward(grad_output.contiguous(), rois, out_h,
-                                    out_w, spatial_scale, sample_num,
-                                    grad_input)
+            grad_input = rois.new_zeros(batch_size, num_channels, data_height,
+                                        data_width)
+            roi_align_cuda.backward(grad_output, rois, out_h, out_w,
+                                    spatial_scale, sample_num, grad_input)
 
         return grad_input, grad_rois, None, None, None
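Editorial note: the key Python-side change is allocating the gradient buffer with rois.new_zeros(...), which inherits both dtype and device from rois, so a half-precision forward pass gets a half-precision grad_input without the deprecated Tensor.new + zero_ + pre-0.4 Variable dance. A minimal plain-PyTorch sketch of the difference (not code from this commit):

    import torch

    rois = torch.zeros(4, 5, device='cuda', dtype=torch.half)

    # old style: uninitialized storage, explicit zeroing, Variable wrapper on top
    grad_old = rois.new(2, 3, 8, 8).zero_()

    # new style: one call, zero-filled, same dtype/device as rois (torch.half here)
    grad_new = rois.new_zeros(2, 3, 8, 8)
    assert grad_new.dtype == torch.half and grad_new.device == rois.device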
mmdet/ops/roi_align/src/roi_align_cuda.cpp

-#include <torch/torch.h>
+#include <torch/extension.h>
 #include <cmath>
 #include <vector>
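Editorial note: <torch/extension.h> is the header PyTorch 1.0 introduced for C++/CUDA extensions (it pulls in ATen and pybind11); <torch/torch.h> was the pre-1.0 spelling and was repurposed for the C++ frontend, with a deprecation warning for extensions. A sketch of how such an op is typically compiled against this header with torch.utils.cpp_extension; the paths are this repo's, the rest is a standard recipe rather than this commit's actual build script:

    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    setup(
        name='roi_align_cuda',
        ext_modules=[
            CUDAExtension('roi_align_cuda', [
                'mmdet/ops/roi_align/src/roi_align_cuda.cpp',
                'mmdet/ops/roi_align/src/roi_align_kernel.cu',
            ]),
        ],
        cmdclass={'build_ext': BuildExtension},
    )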
mmdet/ops/roi_align/src/roi_align_kernel.cu

 #include <ATen/ATen.h>
 #include <THC/THCAtomics.cuh>
 
-using namespace at;  // temporal fix for pytorch<=0.4.1 (see #9848)
-
 #define CUDA_1D_KERNEL_LOOP(i, n)                            \
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
        i += blockDim.x * gridDim.x)

@@ -144,12 +142,7 @@ int ROIAlignForwardLaucher(const at::Tensor features, const at::Tensor rois,
             sample_num, channels, height, width, pooled_height,
             pooled_width, top_data);
       }));
 
-  cudaError_t err = cudaGetLastError();
-  if (cudaSuccess != err) {
-    fprintf(stderr, "cudaCheckError() failed : %s\n",
-            cudaGetErrorString(err));
-    exit(-1);
-  }
+  THCudaCheck(cudaGetLastError());
   return 1;
 }

@@ -280,8 +273,7 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
                             at::Tensor bottom_grad) {
   const int output_size = num_rois * pooled_height * pooled_width * channels;
 
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  // TODO: use AT_DISPATCH_FLOATING_TYPES_AND_HALF when atomicAdd is resolved
+  AT_DISPATCH_FLOATING_TYPES(
       top_grad.type(), "ROIAlignLaucherBackward", ([&] {
         const scalar_t *top_diff = top_grad.data<scalar_t>();
         const scalar_t *rois_data = rois.data<scalar_t>();

@@ -297,11 +289,6 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
             channels, height, width, pooled_height, pooled_width,
             bottom_diff);
       }));
 
-  cudaError_t err = cudaGetLastError();
-  if (cudaSuccess != err) {
-    fprintf(stderr, "cudaCheckError() failed : %s\n",
-            cudaGetErrorString(err));
-    exit(-1);
-  }
+  THCudaCheck(cudaGetLastError());
   return 1;
 }
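Editorial note: two things happen in this kernel file (and its roi_pool twin below). First, the hand-rolled cudaGetLastError()/fprintf/exit(-1) block after each launch collapses into THCudaCheck(...). Second, the backward launcher dispatches with AT_DISPATCH_FLOATING_TYPES rather than the _AND_HALF variant: per the new TODO, atomicAdd on half types was not yet usable here, so half gradients are deferred while the forward path presumably keeps half support (the forward dispatch is not visible in these hunks). A rough Python analogue of what the dispatch macros do, illustrative only; the real dispatch happens in C++:

    import torch

    def dispatch_floating_types(tensor, kernels, and_half=False):
        """Pick a kernel specialization from the runtime dtype.

        Mirrors AT_DISPATCH_FLOATING_TYPES / ..._AND_HALF: float and double
        are always allowed; half only when and_half=True. Unsupported dtypes
        raise, which is what a half backward pass would hit after this commit.
        """
        supported = {torch.float32, torch.float64}
        if and_half:
            supported.add(torch.float16)
        if tensor.dtype not in supported:
            raise RuntimeError('kernel not implemented for {}'.format(tensor.dtype))
        return kernels[tensor.dtype]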
mmdet/ops/roi_pool/functions/roi_pool.py

@@ -24,9 +24,8 @@ class RoIPoolFunction(Function):
         num_channels = features.size(1)
         num_rois = rois.size(0)
         out_size = (num_rois, num_channels, out_h, out_w)
-        output = features.new_zeros(*out_size)
-        argmax = features.new_zeros(out_size, dtype=torch.int)
+        output = features.new_zeros(out_size)
+        argmax = features.new_zeros(*out_size, dtype=torch.int)
         roi_pool_cuda.forward(features, rois, out_h, out_w, spatial_scale,
                               output, argmax)
         ctx.spatial_scale = spatial_scale

@@ -46,9 +45,9 @@ class RoIPoolFunction(Function):
         grad_input = grad_rois = None
         if ctx.needs_input_grad[0]:
-            grad_input = grad_output.new(feature_size).zero_()
-            roi_pool_cuda.backward(grad_output, rois, argmax, spatial_scale,
-                                   grad_input)
+            grad_input = grad_output.new_zeros(feature_size)
+            roi_pool_cuda.backward(grad_output.contiguous(), rois, argmax,
+                                   spatial_scale, grad_input)
 
         return grad_input, grad_rois, None, None
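Editorial note: besides the same new_zeros modernization, the backward call now passes grad_output.contiguous(). The CUDA kernel walks a flat data pointer, so a gradient arriving with permuted strides (which autograd can legitimately produce) would otherwise be read in the wrong element order. A plain-PyTorch illustration of the distinction, not code from this commit:

    import torch

    x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    y = x.t()                    # transpose: same storage, swapped strides
    assert not y.is_contiguous()

    z = y.contiguous()           # copy into the logical row-major layout
    assert z.is_contiguous()
    # z's flat storage now matches the order a raw-pointer kernel expects:
    assert z.reshape(-1).tolist() == [0., 4., 8., 1., 5., 9., 2., 6., 10., 3., 7., 11.]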
mmdet/ops/roi_pool/src/roi_pool_cuda.cpp

-#include <torch/torch.h>
+#include <torch/extension.h>
 #include <cmath>
 #include <vector>
mmdet/ops/roi_pool/src/roi_pool_kernel.cu

 #include <ATen/ATen.h>
 #include <THC/THCAtomics.cuh>
 
-using namespace at;  // temporal fix for pytorch<=0.4.1 (see #9848)
-
 #define CUDA_1D_KERNEL_LOOP(i, n)                            \
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
        i += blockDim.x * gridDim.x)

@@ -100,11 +98,7 @@ int ROIPoolForwardLaucher(const at::Tensor features, const at::Tensor rois,
             channels, height, width, pooled_h, pooled_w, top_data,
             argmax_data);
       }));
 
-  cudaError_t err = cudaGetLastError();
-  if (cudaSuccess != err) {
-    fprintf(stderr, "cudaCheckError() failed : %s\n",
-            cudaGetErrorString(err));
-    exit(-1);
-  }
+  THCudaCheck(cudaGetLastError());
   return 1;
 }

@@ -139,8 +133,7 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
                            const int pooled_w, at::Tensor bottom_grad) {
   const int output_size = num_rois * pooled_h * pooled_w * channels;
 
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  // TODO: use AT_DISPATCH_FLOATING_TYPES_AND_HALF when atomicAdd is resolved
+  AT_DISPATCH_FLOATING_TYPES(
       top_grad.type(), "ROIPoolLaucherBackward", ([&] {
         const scalar_t *top_diff = top_grad.data<scalar_t>();
         const scalar_t *rois_data = rois.data<scalar_t>();

@@ -158,11 +151,6 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
             scalar_t(spatial_scale), channels, height, width, pooled_h,
             pooled_w, bottom_diff);
       }));
 
-  cudaError_t err = cudaGetLastError();
-  if (cudaSuccess != err) {
-    fprintf(stderr, "cudaCheckError() failed : %s\n",
-            cudaGetErrorString(err));
-    exit(-1);
-  }
+  THCudaCheck(cudaGetLastError());
   return 1;
 }
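Editorial note: both kernel files make the same error-handling swap, which is worth spelling out. exit(-1) tore down the entire Python process on a failed launch; THCudaCheck raises through the extension's normal error path, so the caller sees a catchable exception. A rough Python analogue of the before/after, illustrative only:

    import sys

    def check_old(err):
        # old behaviour: print to stderr and kill the process;
        # nothing upstream can catch this
        if err != 0:  # cudaSuccess == 0
            print('cudaCheckError() failed : {}'.format(err), file=sys.stderr)
            sys.exit(-1)

    def check_new(err):
        # new behaviour (THCudaCheck-like): raise, so the Python caller
        # gets an ordinary RuntimeError it can handle
        if err != 0:
            raise RuntimeError('CUDA error {}'.format(err))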