OpenDAS / MMCV · Commit 91da9643

Author:  limm
Date:    Aug 13, 2024
Parent:  6f674c7e
Changes: 139

    support v2.1.0

Showing 20 changed files with 1214 additions and 2206 deletions (+1214 / -2206).
Changed files shown on this page:

  mmcv/ops/csrc/pytorch/cuda/filtered_lrelu.cu           +67   -55
  mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu          +5    -4
  mmcv/ops/csrc/pytorch/focal_loss.cpp                   +87    -0
  mmcv/ops/csrc/pytorch/mlu/ball_query_mlu.cpp           +47    -0
  mmcv/ops/csrc/pytorch/mlu/bbox_overlaps_mlu.cpp        +13   -57
  mmcv/ops/csrc/pytorch/mlu/box_iou_rotated.cpp          +54    -0
  mmcv/ops/csrc/pytorch/mlu/carafe_mlu.cpp               +41  -262
  mmcv/ops/csrc/pytorch/mlu/deform_roi_pool_mlu.cpp      +69  -253
  mmcv/ops/csrc/pytorch/mlu/diff_iou_rotated_mlu.cpp     +55    -0
  mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp   +67  -222
  mmcv/ops/csrc/pytorch/mlu/iou3d_mlu.cpp                +36  -101
  mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.cpp       +150    -0
  mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h         +144    -0
  mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp       +73  -364
  mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp                  +38  -108
  mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp          +55    -0
  mmcv/ops/csrc/pytorch/mlu/psamask_mlu.cpp              +15  -213
  mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp            +68   -89
  mmcv/ops/csrc/pytorch/mlu/roi_align_rotated_mlu.cpp    +43  -156
  mmcv/ops/csrc/pytorch/mlu/roiaware_pool3d_mlu.cpp      +87  -322
mmcv/ops/csrc/pytorch/cuda/filtered_lrelu.cu  (+67 / -55)

Every hunk in this file reworks the HIP/CUDA selection around warp shuffles, kernel
launches and function attributes. Previously the CUDA calls (__shfl_xor_sync,
cudaLaunchKernel) came first under an "#ifndef MMCV_WITH_HIP" guard with the HIP
calls in the "#else" branch, and two spots carried HIP-only code with the CUDA
calls commented out. After this commit every guard reads "#ifdef MMCV_WITH_HIP",
with the HIP variant (__shfl_xor, hipLaunchKernel, hipFuncSetAttribute) first and
the CUDA variant in the "#else" branch. The touched regions, as they read after
the change:

@@ -672,12 +672,12 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
        // Combine signs.
        uint32_t s = sx + sy + sw + sz;
        s <<= (signX & 3) << 1;
#ifdef MMCV_WITH_HIP
        s |= __shfl_xor(s, 1);
        s |= __shfl_xor(s, 2);
#else
        s |= __shfl_xor_sync(groupMask, s, 1);
        s |= __shfl_xor_sync(groupMask, s, 2);
#endif
        // Write signs.

@@ -725,13 +725,14 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
        // Combine signs.
        uint32_t s = sx + sy + sw + sz;
        s <<= (signX & 3) << 1;
#ifdef MMCV_WITH_HIP
        s |= __shfl_xor(s, 1);
        s |= __shfl_xor(s, 2);
#else
        s |= __shfl_xor_sync(groupMask, s, 1);
        s |= __shfl_xor_sync(groupMask, s, 2);
#endif
        // Write signs.
        if ((uint32_t)(signY + 0) < sShapeMaxY) {
          p.s[si0] = (unsigned char)(s >> 0);

@@ -861,13 +862,14 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
        // Combine signs.
        int s = sx + sy;
        s <<= signXo;
#ifdef MMCV_WITH_HIP
        s |= __shfl_xor(s, 1);
        s |= __shfl_xor(s, 2);
#else
        s |= __shfl_xor_sync(groupMask, s, 1);
        s |= __shfl_xor_sync(groupMask, s, 2);
#endif
        // Write signs.
        if ((uint32_t)(signY + 0) < sShapeMaxY) {
          p.s[si0] = (unsigned char)(s >> 0);

@@ -895,13 +897,14 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
        // Combine signs.
        int s = sx + sy;
        s <<= signXo;
#ifdef MMCV_WITH_HIP
        s |= __shfl_xor(s, 1);
        s |= __shfl_xor(s, 2);
#else
        s |= __shfl_xor_sync(groupMask, s, 1);
        s |= __shfl_xor_sync(groupMask, s, 2);
#endif
        // Write signs.
        if ((uint32_t)(signY + 0) < sShapeMaxY) {
          p.s[si0] = (unsigned char)(s >> 0);

@@ -1188,14 +1191,14 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
            }
            if ((uint32_t)signXb < p.swLimit &&
                (uint32_t)signY < p.sShape.y && signY >= minY) {
#ifdef MMCV_WITH_HIP
              s += __shfl_xor(s, 1);  // Coalesce.
              s += __shfl_xor(s, 2);  // Coalesce.
#else
              s += __shfl_xor_sync(groupMask, s, 1);  // Coalesce.
              s += __shfl_xor_sync(groupMask, s, 2);  // Coalesce.
#endif
              p.s[si] = s;  // Write.
            }
          } else {
            // Determine and write sign.

@@ -1211,14 +1214,14 @@ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) {
              s = signXbit * 2;
              v = InternalType<T>::clamp(v, p.clamp);
            }
#ifdef MMCV_WITH_HIP
            s += __shfl_xor(s, 1);  // Coalesce.
            s += __shfl_xor(s, 2);  // Coalesce.
#else
            s += __shfl_xor_sync(groupMask, s, 1);  // Coalesce.
            s += __shfl_xor_sync(groupMask, s, 2);  // Coalesce.
#endif
            p.s[si] = s;  // Write.
          } else {
            // Just compute the value.
            if (v < 0.f) v *= p.slope;

@@ -1438,17 +1441,18 @@ static __global__ void filtered_lrelu_act_kernel(
      // Coalesce into threads 0 and 16 of warp.
      uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
      s <<= ((threadIdx.x & 15) << 1);  // Shift into place.
#ifdef MMCV_WITH_HIP
      s |= __shfl_xor(s, 1);  // Distribute.
      s |= __shfl_xor(s, 2);
      s |= __shfl_xor(s, 4);
      s |= __shfl_xor(s, 8);
#else
      s |= __shfl_xor_sync(m, s, 1);  // Distribute.
      s |= __shfl_xor_sync(m, s, 2);
      s |= __shfl_xor_sync(m, s, 4);
      s |= __shfl_xor_sync(m, s, 8);
#endif
      // Write signs if leader and in p.s.
      if (!(threadIdx.x & 15) && x < p.sShape.x)  // y is always in.
      {

@@ -1627,7 +1631,6 @@ filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(
#endif
#endif
#if CUDA_VERSION < 10020
#undef BUILD_FILTERED_LRELU_OP
#define BUILD_FILTERED_LRELU_OP 0
(the single line removed in this hunk falls outside the captured context)

@@ -1673,11 +1676,15 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
  // Figure out how much shared memory is available on the device.
  int maxSharedBytes = 0;
#ifdef MMCV_WITH_HIP
  cudaDeviceGetAttribute(&maxSharedBytes,
                         // hipDeviceAttributeSharedMemPerBlockOptin,
                         hipDeviceAttributeMaxSharedMemoryPerBlock,
                         x.device().index());
#else
  AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes,
                                       cudaDevAttrMaxSharedMemoryPerBlockOptin,
                                       x.device().index()));
#endif
  int sharedKB = maxSharedBytes >> 10;
  // Populate enough launch parameters to check if a CUDA kernel exists.
(previously an unguarded "int result = cudaDeviceGetAttribute(...)" call that
always queried hipDeviceAttributeMaxSharedMemoryPerBlock, with the CUDA attribute
only present as a commented-out line)

@@ -1875,15 +1882,15 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
    p.tilesXrep = 0;
    p.tilesXdim = 0;
  }
  // Launch filter setup kernel.
#ifdef MMCV_WITH_HIP
  AT_CUDA_CHECK(hipLaunchKernel(spec.setup, 1, 1024, args, 0,
                                at::cuda::getCurrentCUDAStream()));
#else
  AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0,
                                 at::cuda::getCurrentCUDAStream()));
#endif
  // Copy kernels to constant memory.
  if (writeSigns && !readSigns)
    AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream())));

@@ -1895,11 +1902,15 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
  // Set cache and shared memory configurations for main kernel.
  AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
  if (spec.dynamicSharedKB)  // Need dynamically allocated shared memory?
#ifdef MMCV_WITH_HIP
    AT_CUDA_CHECK(hipFuncSetAttribute(
        spec.exec, hipFuncAttributeMaxDynamicSharedMemorySize,
        spec.dynamicSharedKB << 10));
#else
    AT_CUDA_CHECK(cudaFuncSetAttribute(
        spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize,
        spec.dynamicSharedKB << 10));
#endif
  AT_CUDA_CHECK(
      cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
(previously only the hipFuncSetAttribute call, with the cudaFuncSetAttribute call
and the CUDA attribute left as commented-out lines)

@@ -1910,12 +1921,12 @@ std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu_op(
  {
    p.blockZofs = zofs;
    int subGz = std::min(maxSubGz, gz - zofs);
#ifdef MMCV_WITH_HIP
    AT_CUDA_CHECK(hipLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args,
                                  spec.dynamicSharedKB << 10,
                                  at::cuda::getCurrentCUDAStream()));
#else
    AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args,
                                   spec.dynamicSharedKB << 10,
                                   at::cuda::getCurrentCUDAStream()));
#endif

@@ -2033,12 +2044,13 @@ torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx,
  gz = std::min(gz, gmax);
  // Launch.
#ifdef MMCV_WITH_HIP
  AT_CUDA_CHECK(hipLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0,
                                at::cuda::getCurrentCUDAStream()));
#else
  AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0,
                                 at::cuda::getCurrentCUDAStream()));
#endif
  return so;
}
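The hunks above repeat the same guard many times. As a point of reference only,
here is a minimal sketch, not part of this commit, of the same selection written
once as a helper; portable_shfl_xor is a made-up name, while the intrinsics are
exactly the ones already used in the diff:

// Sketch only: centralizing the __shfl_xor vs. __shfl_xor_sync choice that the
// hunks above make inline. Compile with hipcc (MMCV_WITH_HIP defined) or nvcc.
#include <cstdint>

__device__ inline uint32_t portable_shfl_xor(uint32_t mask, uint32_t v,
                                             int lane_mask) {
#ifdef MMCV_WITH_HIP
  (void)mask;                    // HIP's __shfl_xor takes no participation mask
  return __shfl_xor(v, lane_mask);
#else
  return __shfl_xor_sync(mask, v, lane_mask);  // CUDA sync variant
#endif
}

// Usage, mirroring the sign-combining code in filtered_lrelu_kernel:
//   s |= portable_shfl_xor(groupMask, s, 1);
//   s |= portable_shfl_xor(groupMask, s, 2);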
mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu  (+5 / -4)

Same change as in filtered_lrelu.cu: the "#ifndef MMCV_WITH_HIP" guard around the
kernel launch becomes "#ifdef MMCV_WITH_HIP", with the HIP launch first.

@@ -734,12 +734,13 @@ torch::Tensor upfirdn2d_op(torch::Tensor x, torch::Tensor f, int upx, int upy,
  // Launch CUDA kernel.
  void *args[] = {&p};
#ifdef MMCV_WITH_HIP
  AT_CUDA_CHECK(hipLaunchKernel(spec.kernel, gridSize, blockSize, args, 0,
                                at::cuda::getCurrentCUDAStream()));
#else
  AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0,
                                 at::cuda::getCurrentCUDAStream()));
#endif
  return y;
}
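This hunk is the launch-selection pattern in isolation. A minimal sketch, assuming
nothing beyond the calls already present in the diff, of what a shared helper for
it could look like; launch_on_current_stream is a hypothetical name, not an MMCV
function:

// Sketch only: one place to pick hipLaunchKernel vs. cudaLaunchKernel on the
// current PyTorch stream, as the guarded blocks above do inline.
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

inline void launch_on_current_stream(const void* kernel, dim3 grid, dim3 block,
                                     void** args, size_t shared_bytes) {
#ifdef MMCV_WITH_HIP
  AT_CUDA_CHECK(hipLaunchKernel(kernel, grid, block, args, shared_bytes,
                                at::cuda::getCurrentCUDAStream()));
#else
  AT_CUDA_CHECK(cudaLaunchKernel(kernel, grid, block, args, shared_bytes,
                                 at::cuda::getCurrentCUDAStream()));
#endif
}

// e.g. for the hunk above:  void* args[] = {&p};
//                           launch_on_current_stream(spec.kernel, gridSize,
//                                                    blockSize, args, 0);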
mmcv/ops/csrc/pytorch/focal_loss.cpp  (+87 / -0)

Pure addition: an optional DIOPI path, compiled in under MMCV_WITH_DIOPI, that
dispatches sigmoid focal loss to the DIOPI vendor kernels and falls back to the
existing *_impl functions on the host or when the vendor op is unavailable.
Everything guarded by MMCV_WITH_DIOPI below is new; the rest is unchanged context.

// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
#ifdef MMCV_WITH_DIOPI
#include <diopi/diopirt.h>
#include <diopi/functions.h>
#include <diopi/functions_mmcv.h>

#include "csrc_dipu/diopirt/diopirt_impl.h"

using dipu::diopi_helper::toDiopiScalar;
using dipu::diopi_helper::toDiopiTensorHandle;
#endif

void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha) {

@@ -29,15 +39,92 @@ void softmax_focal_loss_backward_impl(Tensor input, Tensor target,
                                    buff, grad_input, gamma, alpha);
}

#ifdef MMCV_WITH_DIOPI
void sigmoid_focal_loss_forward_diopi(Tensor input, Tensor target,
                                      Tensor weight, Tensor output, float gamma,
                                      float alpha) {
  auto input_p = toDiopiTensorHandle(input);
  diopiDevice_t device;
  diopiGetTensorDevice(input_p, &device);
  if (device == diopi_host) {
    sigmoid_focal_loss_forward_impl(input, target, weight, output, gamma, alpha);
    return;
  }
  diopiContext ctx(dipu::getCurrentDIPUStream().rawstream());
  diopiContextHandle_t ch = &ctx;
  auto target_p = toDiopiTensorHandle(target);
  auto weight_p = toDiopiTensorHandle(weight);
  auto output_p = toDiopiTensorHandle(output);
  if (reinterpret_cast<void *>(diopiSigmoidFocalLossMmcv) != nullptr) {
    auto ret = diopiSigmoidFocalLossMmcv(ch, output_p, input_p, target_p,
                                         weight_p, gamma, alpha);
    if (ret == diopiSuccess) return;
  }
  LOG(WARNING) << "Fallback to cpu: mmcv ext op sigmoid_focal_loss_forward_impl";
  auto input_cpu = input.cpu();
  auto target_cpu = target.cpu();
  auto weight_cpu = weight.cpu();
  auto output_cpu = output.cpu();
  sigmoid_focal_loss_forward_impl(input_cpu, target_cpu, weight_cpu, output_cpu,
                                  gamma, alpha);
  output.copy_(output_cpu);
  return;
}

void sigmoid_focal_loss_backward_diopi(Tensor input, Tensor target,
                                       Tensor weight, Tensor grad_input,
                                       float gamma, float alpha) {
  auto input_p = toDiopiTensorHandle(input);
  diopiDevice_t device;
  diopiGetTensorDevice(input_p, &device);
  if (device == diopi_host) {
    sigmoid_focal_loss_backward_impl(input, target, weight, grad_input, gamma,
                                     alpha);
    return;
  }
  diopiContext ctx(dipu::getCurrentDIPUStream().rawstream());
  diopiContextHandle_t ch = &ctx;
  auto target_p = toDiopiTensorHandle(target);
  auto weight_p = toDiopiTensorHandle(weight);
  auto grad_input_p = toDiopiTensorHandle(grad_input);
  if (reinterpret_cast<void *>(diopiSigmoidFocalLossBackwardMmcv) != nullptr) {
    auto ret = diopiSigmoidFocalLossBackwardMmcv(ch, grad_input_p, input_p,
                                                 target_p, weight_p, gamma,
                                                 alpha);
    if (ret == diopiSuccess) return;
  }
  LOG(WARNING) << "Fallback to cpu: mmcv ext op sigmoid_focal_loss_forward_impl";
  auto input_cpu = input.cpu();
  auto target_cpu = target.cpu();
  auto weight_cpu = weight.cpu();
  auto grad_input_cpu = grad_input.cpu();
  sigmoid_focal_loss_backward_impl(input_cpu, target_cpu, weight_cpu,
                                   grad_input_cpu, gamma, alpha);
  grad_input.copy_(grad_input_cpu);
  return;
}
#endif

void sigmoid_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
                                Tensor output, float gamma, float alpha) {
#ifdef MMCV_WITH_DIOPI
  sigmoid_focal_loss_forward_diopi(input, target, weight, output, gamma, alpha);
#else
  sigmoid_focal_loss_forward_impl(input, target, weight, output, gamma, alpha);
#endif
}

void sigmoid_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
                                 Tensor grad_input, float gamma, float alpha) {
#ifdef MMCV_WITH_DIOPI
  sigmoid_focal_loss_backward_diopi(input, target, weight, grad_input, gamma,
                                    alpha);
#else
  sigmoid_focal_loss_backward_impl(input, target, weight, grad_input, gamma,
                                   alpha);
#endif
}

void softmax_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
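One detail worth noting in the new DIOPI path: the wrappers test the address of
diopiSigmoidFocalLossMmcv before calling it. That null check only makes sense if
the DIOPI entry points are declared as weak symbols, so the address is null when
no vendor implementation is linked in; this is an assumption about the DIOPI
headers, not something visible in this diff. A standalone sketch of the mechanism,
with vendor_fill as a made-up stand-in:

// Sketch only: probing a weakly-declared function at runtime and falling back,
// the same shape as the diopiSigmoidFocalLossMmcv null check above.
#include <cstdio>

extern "C" int vendor_fill(float* data, int n) __attribute__((weak));

void fill_ones(float* data, int n) {
  if (reinterpret_cast<void*>(vendor_fill) != nullptr &&
      vendor_fill(data, n) == 0) {
    return;                                   // vendor path succeeded
  }
  std::fprintf(stderr, "Fallback to cpu\n");  // mirrors the LOG(WARNING) above
  for (int i = 0; i < n; ++i) data[i] = 1.0f; // portable fallback
}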
mmcv/ops/csrc/pytorch/mlu/ball_query_mlu.cpp  (new file, mode 100644, +47 / -0)

/*************************************************************************
 * Copyright (C) 2022 Cambricon.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "mlu_common_helper.h"

void ball_query_forward_mlu(int b, int n, int m, float min_radius,
                            float max_radius, int nsample,
                            const Tensor new_xyz, const Tensor xyz,
                            Tensor idx) {
  auto new_xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      new_xyz, new_xyz.suggest_memory_format());
  auto xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      xyz, new_xyz.suggest_memory_format());
  auto idx_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      idx, new_xyz.suggest_memory_format());

  MluOpTensorDescriptor new_xyz_desc, xyz_desc, idx_desc;
  new_xyz_desc.set(new_xyz_contiguous);
  xyz_desc.set(xyz_contiguous);
  idx_desc.set(idx_contiguous);

  auto new_xyz_impl = torch_mlu::getMluTensorImpl(new_xyz_contiguous);
  auto xyz_impl = torch_mlu::getMluTensorImpl(xyz_contiguous);
  auto idx_impl = torch_mlu::getMluTensorImpl(idx_contiguous);
  auto new_xyz_ptr = new_xyz_impl->cnnlMalloc();
  auto xyz_ptr = xyz_impl->cnnlMalloc();
  auto idx_ptr = idx_impl->cnnlMalloc();

  auto handle = mluOpGetCurrentHandle();
  TORCH_MLUOP_CHECK(mluOpBallQuery(handle, new_xyz_desc.desc(), new_xyz_ptr,
                                   xyz_desc.desc(), xyz_ptr, min_radius,
                                   max_radius, nsample, idx_desc.desc(),
                                   idx_ptr));
}

void ball_query_forward_impl(int b, int n, int m, float min_radius,
                             float max_radius, int nsample,
                             const Tensor new_xyz, const Tensor xyz,
                             Tensor idx);

REGISTER_DEVICE_IMPL(ball_query_forward_impl, MLU, ball_query_forward_mlu);
mmcv/ops/csrc/pytorch/mlu/bbox_overlaps_mlu.cpp  (+13 / -57)

The hand-written BANG kernel path is dropped: the pytorch_device_registry.hpp and
pytorch_mlu_helper.hpp includes, the KernelBBoxOverlaps declaration, the Union1
policyFunc, the BBoxOverlapsMLUKernelLauncher wrapper, and the queue/dtype/pointer
plumbing and CNLOG launch logging are all removed. bbox_overlaps_mlu now calls
mluOpBboxOverlaps directly. The two hunks read as follows after the change:

@@ -10,36 +10,11 @@
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "mlu_common_helper.h"

void bbox_overlaps_mlu(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
                       const int32_t mode, const bool aligned,
                       const int32_t offset) {
  // check dtype
  TORCH_CHECK(
      bboxes1.scalar_type() == at::kFloat || bboxes1.scalar_type() == at::kHalf,

@@ -63,38 +38,19 @@ void BBoxOverlapsMLUKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,
    return;
  }
  INITIAL_MLU_PARAM_WITH_TENSOR(bboxes1);
  INITIAL_MLU_PARAM_WITH_TENSOR(bboxes2);
  INITIAL_MLU_PARAM_WITH_TENSOR(ious);

  // get compute handle
  auto handle = mluOpGetCurrentHandle();
  TORCH_MLUOP_CHECK(mluOpBboxOverlaps(
      handle, mode, aligned, offset, bboxes1_desc.desc(), bboxes1_ptr,
      bboxes2_desc.desc(), bboxes2_ptr, ious_desc.desc(), ious_ptr));
}

void bbox_overlaps_impl(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
                        const int mode, const bool aligned, const int offset);

REGISTER_DEVICE_IMPL(bbox_overlaps_impl, MLU, bbox_overlaps_mlu);
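INITIAL_MLU_PARAM_WITH_TENSOR comes from the new mlu_common_helper.h, which is not
shown in this capture. Judging by the names it has to produce (bboxes1_desc,
bboxes1_ptr, ...) and by the hand-written equivalents in ball_query_mlu.cpp above,
it presumably expands to something like the following; this is an inferred sketch,
not the actual macro:

// Assumed expansion of INITIAL_MLU_PARAM_WITH_TENSOR(t): make the tensor
// contiguous, build an mluOp descriptor, and fetch the device pointer.
#define INITIAL_MLU_PARAM_WITH_TENSOR(t)                               \
  auto t##_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(         \
      (t), (t).suggest_memory_format());                               \
  MluOpTensorDescriptor t##_desc;                                      \
  t##_desc.set(t##_contiguous);                                        \
  auto t##_ptr = torch_mlu::getMluTensorImpl(t##_contiguous)->cnnlMalloc();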
mmcv/ops/csrc/pytorch/mlu/box_iou_rotated.cpp  (new file, mode 100644, +54 / -0)

/*************************************************************************
 * Copyright (C) 2022 by Cambricon.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "mlu_common_helper.h"

void BoxIouRotatedMLUKernelLauncher(const Tensor boxes1, const Tensor boxes2,
                                    Tensor ious, const int mode_flag,
                                    const bool aligned) {
  // get compute handle
  auto handle = mluOpGetCurrentHandle();

  auto boxes1_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      boxes1, boxes1.suggest_memory_format());
  auto boxes2_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      boxes2, boxes2.suggest_memory_format());
  auto ious_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      ious, ious.suggest_memory_format());

  MluOpTensorDescriptor boxes1_desc, boxes2_desc, ious_desc;
  boxes1_desc.set(boxes1_contiguous);
  boxes2_desc.set(boxes2_contiguous);
  ious_desc.set(ious_contiguous);

  auto boxes1_impl = torch_mlu::getMluTensorImpl(boxes1_contiguous);
  auto boxes2_impl = torch_mlu::getMluTensorImpl(boxes2_contiguous);
  auto ious_impl = torch_mlu::getMluTensorImpl(ious_contiguous);
  auto boxes1_ptr = boxes1_impl->cnnlMalloc();
  auto boxes2_ptr = boxes2_impl->cnnlMalloc();
  auto ious_ptr = ious_impl->cnnlMalloc();

  CNLOG(INFO) << "Call mluOpBoxIouRotated().";
  TORCH_MLUOP_CHECK(mluOpBoxIouRotated(handle, mode_flag, aligned,
                                       boxes1_desc.desc(), boxes1_ptr,
                                       boxes2_desc.desc(), boxes2_ptr,
                                       ious_desc.desc(), ious_ptr));
}

void box_iou_rotated_mlu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
                         const int mode_flag, const bool aligned) {
  BoxIouRotatedMLUKernelLauncher(boxes1, boxes2, ious, mode_flag, aligned);
}

void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
                          const int mode_flag, const bool aligned);

REGISTER_DEVICE_IMPL(box_iou_rotated_impl, MLU, box_iou_rotated_mlu);
mmcv/ops/csrc/pytorch/mlu/carafe_mlu.cpp  (+41 / -262)

The BANG-kernel plumbing is removed and replaced by mluOp calls. Deleted in this
file: the carafe_utils.hpp / pytorch_device_registry.hpp / pytorch_mlu_helper.hpp
includes, the KernelCarafeForward and KernelCarafeBackward declarations,
getNramUsage (the NRAM stride and size bookkeeping for the input/mask/output
tiles and the sum array), genPolicyForward (the loop that shrinks the
Ho/Wo/Kh/Kw/G/Cg block dims until a tile fits in NRAM, then fills the host
strides, job_num and task dims), policyFuncBackward, and the queue/dtype/CNLOG
launch code in both launchers. The forward and backward launchers now build
MluOpTensorDescriptors, set an mluOpCarafeDescriptor_t, and call
mluOpCarafeForward / mluOpCarafeBackward. The touched regions, as they read after
the change:

@@ -9,200 +9,13 @@
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "mlu_common_helper.h"

void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
                                    Tensor rinput, Tensor routput, Tensor rmask,
                                    Tensor output, const int kernel_size,
                                    const int group_size,
                                    const int scale_factor) {
  // check tensor data type
  TORCH_CHECK(
      input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,

@@ -221,37 +34,10 @@ void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
  // return fast on zero-element tensor
  if (output.numel() == 0) {
    output = at::zeros(output.sizes().vec(), output.options());
    return;
  }

  // convert NCHW to NHWC
  auto memory_format_input_nhwc =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
(previously this hunk filled a CarafeForwardParam struct with N/Ci/Hi/Wi, the
kernel/group/scale sizes, dtype size, NRAM/NFU alignment and Ho/Wo, and called
genPolicyForward to pick the block/grid dims)

@@ -268,6 +54,12 @@ void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
  auto routput_ =
      torch_mlu::cnnl::ops::cnnl_contiguous(output, memory_format_output_nhwc);

  // set tensor descriptor
  MluOpTensorDescriptor input_desc, mask_desc, output_desc;
  input_desc.set_with_layout(rinput_, MLUOP_LAYOUT_NHWC);
  mask_desc.set_with_layout(rmask_, MLUOP_LAYOUT_NHWC);
  output_desc.set_with_layout(routput_, MLUOP_LAYOUT_NHWC);

  // get ptr of tensors
  auto input_impl = torch_mlu::getMluTensorImpl(rinput_);
  auto input_ptr = input_impl->cnnlMalloc();

@@ -276,45 +68,29 @@ void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
  auto output_impl = torch_mlu::getMluTensorImpl(routput_);
  auto output_ptr = output_impl->cnnlMalloc();

  // set op descriptor
  auto handle = mluOpGetCurrentHandle();
  mluOpCarafeDescriptor_t carafe_desc;
  TORCH_MLUOP_CHECK(mluOpCreateCarafeDescriptor(&carafe_desc));
  TORCH_MLUOP_CHECK(mluOpSetCarafeDescriptor(
      carafe_desc, input.dim(), kernel_size, group_size, scale_factor));
  // launch kernel
  TORCH_MLUOP_CHECK(mluOpCarafeForward(handle, carafe_desc, input_desc.desc(),
                                       input_ptr, mask_desc.desc(), mask_ptr,
                                       output_desc.desc(), output_ptr));
  // destroy op descriptor
  TORCH_MLUOP_CHECK(mluOpDestroyCarafeDescriptor(carafe_desc));

  // copy output from NHWC back into NCHW
  rinput.copy_(rinput_);
  output.copy_(routput_);
}

void CARAFEBackwardMLUKernelLauncher(
    const Tensor grad_output, const Tensor rinput, const Tensor mask,
    Tensor rgrad_output, Tensor rgrad_input_hs, Tensor rgrad_input,
    Tensor rgrad_mask, Tensor grad_input, Tensor grad_mask,
    const int kernel_size, const int group_size, const int scale_factor) {
  // data type check
  TORCH_CHECK(grad_output.scalar_type() == at::kFloat ||
                  grad_output.scalar_type() == at::kHalf,
(previously this hunk held the getCurQueue/toCnrtDtype/CNLOG plumbing and the
KernelCarafeForward launch, the policyFuncBackward helper, and batch_size,
channels, hi, wi locals at the top of the backward launcher)

@@ -331,11 +107,6 @@ void CARAFEBackwardMLUKernelLauncher(
  TORCH_CHECK(kernel_size < 137, "kernel_size should be less than 137, got ",
              kernel_size);

  // convert NCHW to NHWC
  auto memory_format_input_nhwc =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(rinput.dim());
(the cnrtDim3_t / cnrtFunctionType_t task-dimension setup via policyFuncBackward
is removed here)

@@ -363,8 +134,15 @@ void CARAFEBackwardMLUKernelLauncher(
  auto rgrad_mask_ = torch_mlu::cnnl::ops::cnnl_contiguous(
      grad_mask, memory_format_grad_mask_nhwc);

  // set tensor descriptor
  MluOpTensorDescriptor input_desc, mask_desc;
  input_desc.set_with_layout(rinput_, MLUOP_LAYOUT_NHWC);
  mask_desc.set_with_layout(rmask_, MLUOP_LAYOUT_NHWC);
  MluOpTensorDescriptor grad_output_desc, grad_input_desc, grad_mask_desc;
  grad_output_desc.set_with_layout(rgrad_output_, MLUOP_LAYOUT_NHWC);
  grad_input_desc.set_with_layout(rgrad_input_, MLUOP_LAYOUT_NHWC);
  grad_mask_desc.set_with_layout(rgrad_mask_, MLUOP_LAYOUT_NHWC);

  // get ptr of tensors
  auto input_impl = torch_mlu::getMluTensorImpl(rinput_);
(replacing the old "// get compute queue" / getCurQueue call)

@@ -378,19 +156,20 @@ void CARAFEBackwardMLUKernelLauncher(
  auto grad_mask_impl = torch_mlu::getMluTensorImpl(rgrad_mask_);
  auto grad_mask_ptr = grad_mask_impl->cnnlMalloc();

  // set op descriptor
  auto handle = mluOpGetCurrentHandle();
  mluOpCarafeDescriptor_t carafe_desc;
  TORCH_MLUOP_CHECK(mluOpCreateCarafeDescriptor(&carafe_desc));
  TORCH_MLUOP_CHECK(mluOpSetCarafeDescriptor(
      carafe_desc, grad_output.dim(), kernel_size, group_size, scale_factor));
  // launch kernel
  TORCH_MLUOP_CHECK(mluOpCarafeBackward(
      handle, carafe_desc, input_desc.desc(), input_ptr, mask_desc.desc(),
      mask_ptr, grad_output_desc.desc(), grad_output_ptr,
      grad_input_desc.desc(), grad_input_ptr, grad_mask_desc.desc(),
      grad_mask_ptr));
  // destroy op descriptor
  TORCH_MLUOP_CHECK(mluOpDestroyCarafeDescriptor(carafe_desc));

  // copy output from NHWC back into NCHW
  grad_input.copy_(rgrad_input_);
(replacing the old toCnrtDtype/getDeviceAttr/CNLOG code and the
KernelCarafeBackward launch)
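The carafe changes are the only place in this commit that manages an op descriptor
by hand (create, set, use, destroy). A small sketch, not part of the diff, of how
that lifecycle could be wrapped so the destroy call is not skipped on early error
paths; only the mluOp calls already used above are assumed, and CarafeDescGuard is
a made-up name:

// Sketch only: RAII wrapper around the mluOpCarafeDescriptor_t lifecycle used by
// mluOpCarafeForward / mluOpCarafeBackward in the hunks above.
#include "mlu_common_helper.h"

struct CarafeDescGuard {
  mluOpCarafeDescriptor_t desc = nullptr;
  CarafeDescGuard(int dims, int kernel_size, int group_size, int scale_factor) {
    TORCH_MLUOP_CHECK(mluOpCreateCarafeDescriptor(&desc));
    TORCH_MLUOP_CHECK(mluOpSetCarafeDescriptor(desc, dims, kernel_size,
                                               group_size, scale_factor));
  }
  ~CarafeDescGuard() {
    if (desc) mluOpDestroyCarafeDescriptor(desc);  // always released
  }
  CarafeDescGuard(const CarafeDescGuard&) = delete;
  CarafeDescGuard& operator=(const CarafeDescGuard&) = delete;
};

// Usage: CarafeDescGuard guard(input.dim(), kernel_size, group_size, scale_factor);
//        TORCH_MLUOP_CHECK(mluOpCarafeForward(handle, guard.desc, ...));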
mmcv/ops/csrc/pytorch/mlu/deform_roi_pool_mlu.cpp
View file @
91da9643
...
@@ -9,254 +9,59 @@
...
@@ -9,254 +9,59 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "mlu_common_helper.h"
#include "pytorch_mlu_helper.hpp"
void
KernelDeformRoIPoolForward
(
cnrtDim3_t
k_dim
,
cnrtFunctionType_t
k_type
,
cnrtQueue_t
queue
,
cnrtDataType_t
data_type
,
const
void
*
input
,
const
void
*
rois
,
const
void
*
offset
,
void
*
output
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
num_rois
,
const
int
pooled_height
,
const
int
pooled_width
,
const
float
spatial_scale
,
const
int
sampling_ratio
,
const
float
gamma
);
void
KernelDeformRoIPoolBackward
(
cnrtDim3_t
k_dim
,
cnrtFunctionType_t
k_type
,
cnrtQueue_t
queue
,
cnrtDataType_t
data_type
,
const
void
*
grad_output
,
const
void
*
input
,
const
void
*
rois
,
const
void
*
offset
,
void
*
grad_input
,
void
*
grad_offset
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
num_rois
,
const
int
pooled_height
,
const
int
pooled_width
,
const
float
spatial_scale
,
const
int
sampling_ratio
,
const
float
gamma
);
// policy function for forward and backward
static
void
policyFunc
(
const
int
bin_num
,
cnrtDim3_t
*
k_dim
,
cnrtFunctionType_t
*
k_type
)
{
const
size_t
cluster_limit
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrClusterCount
);
;
const
size_t
core_limit
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrMcorePerCluster
);
const
size_t
bin_num_align
=
CEIL_ALIGN
(
bin_num
,
core_limit
);
k_dim
->
x
=
core_limit
;
k_dim
->
y
=
(
bin_num_align
/
core_limit
)
>
cluster_limit
?
cluster_limit
:
(
bin_num_align
/
core_limit
);
k_dim
->
z
=
1
;
*
k_type
=
CNRT_FUNC_TYPE_UNION1
;
}
void
DeformRoIPoolForwardMLUKernelLauncher
(
Tensor
input
,
Tensor
rois
,
void
DeformRoIPoolForwardMLUKernelLauncher
(
Tensor
input
,
Tensor
rois
,
Tensor
offset
,
Tensor
output
,
Tensor
offset
,
Tensor
output
,
int
pooled_height
,
int
pooled_width
,
int
pooled_height
,
int
pooled_width
,
float
spatial_scale
,
float
spatial_scale
,
int
sampling_ratio
,
float
gamma
)
{
int
sampling_ratio
,
float
gamma
)
{
// Check dtype.
TORCH_CHECK
(
input
.
scalar_type
()
==
at
::
kFloat
||
input
.
scalar_type
()
==
at
::
kHalf
,
"input type should be Float or Half, got "
,
input
.
scalar_type
());
TORCH_CHECK
(
input
.
scalar_type
()
==
rois
.
scalar_type
(),
"rois should have the same type as input"
);
// Check shape.
TORCH_CHECK
(
input
.
dim
()
==
4
,
"input should be 4d tensor, got "
,
input
.
dim
(),
"D."
);
TORCH_CHECK
(
rois
.
dim
()
==
2
,
"rois should be 2d tensor, got "
,
rois
.
dim
(),
"D."
);
if
(
offset
.
defined
()
&&
offset
.
numel
()
>
0
)
{
TORCH_CHECK
(
input
.
scalar_type
()
==
offset
.
scalar_type
(),
"offset should have the same type as input"
);
TORCH_CHECK
(
offset
.
dim
()
==
4
,
"offset should be 4d tensor, got "
,
offset
.
dim
(),
"D."
);
TORCH_CHECK
(
(
offset
.
size
(
0
)
==
rois
.
size
(
0
)),
"offset.size(0) = "
,
offset
.
size
(
0
),
"while rois.size(0)) = "
,
rois
.
size
(
0
),
". They should be the same."
);
TORCH_CHECK
((
offset
.
size
(
1
)
==
2
),
"offset.size(1) should be 2, "
,
"but now offset.size(1) = "
,
offset
.
size
(
1
),
"."
);
TORCH_CHECK
((
offset
.
size
(
2
)
==
output
.
size
(
2
)),
"offset.size(2) = "
,
offset
.
size
(
2
),
"while output.size(2)) = "
,
output
.
size
(
2
),
". They should be the same."
);
TORCH_CHECK
((
offset
.
size
(
3
)
==
output
.
size
(
3
)),
"offset.size(3) = "
,
offset
.
size
(
3
),
"while output.size(3)) = "
,
output
.
size
(
3
),
". They should be the same."
);
}
TORCH_CHECK
(
spatial_scale
>
0
&&
spatial_scale
<=
1
,
"spatial_scale should be within (0, 1], got "
,
spatial_scale
,
"."
);
// compute kernel params
auto
height
=
input
.
size
(
2
);
auto
width
=
input
.
size
(
3
);
auto
channels
=
input
.
size
(
1
);
auto
num_rois
=
output
.
size
(
0
);
if
(
output
.
numel
()
==
0
)
{
output
=
at
::
zeros
({
num_rois
,
channels
,
pooled_height
,
pooled_width
},
input
.
options
());
return
;
}
// zero element check
TORCH_CHECK
(
input
.
size
(
0
)
!=
0
,
"input.size(0) should not be zero, got "
,
input
.
size
(
0
));
TORCH_CHECK
(
rois
.
numel
()
!=
0
,
"rois.numel() should not be zero, got "
,
rois
.
numel
());
if
(
input
.
numel
()
==
0
||
output
.
numel
()
==
0
)
{
return
;
}
// large tensor check
const
size_t
max_input_num
=
2147483648
;
// 2^31, 2G num
TORCH_CHECK
(
input
.
numel
()
<
max_input_num
,
"input.numel() should be less than 2147483648, got "
,
input
.
numel
());
TORCH_CHECK
(
rois
.
numel
()
<
max_input_num
,
"rois.numel() should be less than 2147483648, got "
,
rois
.
numel
());
TORCH_CHECK
(
output
.
numel
()
<
max_input_num
,
"output.numel() should be less than 2147483648, got "
,
output
.
numel
());
TORCH_CHECK
(
!
offset
.
defined
()
||
offset
.
numel
()
<
max_input_num
,
"offset.numel() should be less than 2147483648, got "
,
offset
.
numel
());
auto
memory_format
=
auto
memory_format
=
torch_mlu
::
cnnl
::
ops
::
get_channels_last_memory_format
(
input
.
dim
());
torch_mlu
::
cnnl
::
ops
::
get_channels_last_memory_format
(
input
.
dim
());
auto
input_
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
input
,
memory_format
);
auto
input_
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
input
,
memory_format
);
auto
rois_contiguous
=
at
::
Tensor
output_
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
rois
,
rois
.
suggest_memory_format
());
at
::
empty
({
num_rois
,
channels
,
pooled_height
,
pooled_width
},
auto
output_contiguous
=
input
.
options
(),
memory_format
);
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
output
,
memory_format
);
// calculate task dimension
MluOpTensorDescriptor
input_desc
,
rois_desc
,
offset_desc
,
output_desc
;
cnrtDim3_t
k_dim
;
input_desc
.
set_with_layout
(
input_
,
MLUOP_LAYOUT_NHWC
);
cnrtFunctionType_t
k_type
;
rois_desc
.
set
(
rois_contiguous
);
policyFunc
(
num_rois
*
pooled_height
*
pooled_width
,
&
k_dim
,
&
k_type
);
output_desc
.
set_with_layout
(
output_contiguous
,
MLUOP_LAYOUT_NHWC
);
// get compute queue
mluOpTensorDescriptor_t
offset_real_desc
=
NULL
;
auto
queue
=
torch_mlu
::
getCurQueue
();
void
*
offset_ptr
=
NULL
;
if
(
offset
.
defined
()
&&
offset
.
numel
()
>
0
)
{
auto
offset_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
offset
,
offset
.
suggest_memory_format
());
offset_desc
.
set
(
offset_contiguous
);
offset_real_desc
=
offset_desc
.
desc
();
auto
offset_impl
=
torch_mlu
::
getMluTensorImpl
(
offset_contiguous
);
offset_ptr
=
offset_impl
->
cnnlMalloc
();
}
// get ptr of tensors
// get ptr of tensors
auto
input_impl
=
torch_mlu
::
getMluTensorImpl
(
input_
);
auto
input_impl
=
torch_mlu
::
getMluTensorImpl
(
input_
);
auto
input_ptr
=
input_impl
->
cnnlMalloc
();
auto
input_ptr
=
input_impl
->
cnnlMalloc
();
auto
rois_impl
=
torch_mlu
::
getMluTensorImpl
(
rois
);
auto
rois_impl
=
torch_mlu
::
getMluTensorImpl
(
rois
_contiguous
);
auto
rois_ptr
=
rois_impl
->
cnnlMalloc
();
auto
rois_ptr
=
rois_impl
->
cnnlMalloc
();
auto
offset_impl
=
torch_mlu
::
getMluTensorImpl
(
offset
);
auto
output_impl
=
torch_mlu
::
getMluTensorImpl
(
output_contiguous
);
auto
offset_ptr
=
offset_impl
->
cnnlMalloc
();
auto
output_impl
=
torch_mlu
::
getMluTensorImpl
(
output_
);
auto
output_ptr
=
output_impl
->
cnnlMalloc
();
auto
output_ptr
=
output_impl
->
cnnlMalloc
();
// get comput dtype of input
// get compute handle
cnrtDataType_t
data_type
=
torch_mlu
::
toCnrtDtype
(
input_
.
dtype
());
auto
handle
=
mluOpGetCurrentHandle
();
TORCH_MLUOP_CHECK
(
mluOpDeformRoiPoolForward
(
handle
,
input_desc
.
desc
(),
input_ptr
,
rois_desc
.
desc
(),
rois_ptr
,
offset_real_desc
,
offset_ptr
,
pooled_height
,
pooled_width
,
spatial_scale
,
sampling_ratio
,
gamma
,
output_desc
.
desc
(),
output_ptr
));
// launch kernel
output
.
copy_
(
output_contiguous
);
CNLOG
(
INFO
)
<<
"Launch Kernel MLUKernelDeformRoIPoolForward<<<"
<<
k_dim
.
x
<<
", "
<<
k_dim
.
y
<<
", "
<<
k_dim
.
z
<<
">>>"
;
KernelDeformRoIPoolForward
(
k_dim
,
k_type
,
queue
,
data_type
,
input_ptr
,
rois_ptr
,
offset_ptr
,
output_ptr
,
channels
,
height
,
width
,
num_rois
,
pooled_height
,
pooled_width
,
spatial_scale
,
sampling_ratio
,
gamma
);
output
.
copy_
(
output_
);
}
}
void DeformRoIPoolBackwardMLUKernelLauncher(
    Tensor grad_output, Tensor input, Tensor rois, Tensor offset,
    Tensor grad_input, Tensor grad_offset, int pooled_height, int pooled_width,
    float spatial_scale, int sampling_ratio, float gamma) {
  // Check dtype.
  TORCH_CHECK(
      input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
      "input type should be Float or Half, got ", input.scalar_type());
  TORCH_CHECK(input.scalar_type() == grad_output.scalar_type(),
              "grad_output should have the same type as input");
  TORCH_CHECK(input.scalar_type() == rois.scalar_type(),
              "rois should have the same type as input");
  TORCH_CHECK(input.scalar_type() == grad_input.scalar_type(),
              "grad_input should have the same type as input");

  // Check shape.
  TORCH_CHECK(grad_output.dim() == 4, "grad_output should be 4d tensor, got ",
              grad_output.dim(), "D.");
  TORCH_CHECK(input.dim() == 4, "input should be 4d tensor, got ", input.dim(),
              "D.");
  TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(),
              "D.");
  if (offset.defined() && offset.numel() > 0) {
    TORCH_CHECK(input.scalar_type() == offset.scalar_type(),
                "offset should have the same type as input");
    TORCH_CHECK(offset.dim() == 4, "offset should be 4d tensor, got ",
                offset.dim(), "D.");
    TORCH_CHECK((offset.size(0) == rois.size(0)), "offset.size(0) = ",
                offset.size(0), "while rois.size(0)) = ", rois.size(0),
                ". They should be the same.");
    TORCH_CHECK((offset.size(1) == 2), "offset.size(1) should be 2, ",
                "but now offset.size(1) = ", offset.size(1), ".");
    TORCH_CHECK((offset.size(2) == grad_output.size(2)), "offset.size(2) = ",
                offset.size(2), "while grad_output.size(2)) = ",
                grad_output.size(2), ". They should be the same.");
    TORCH_CHECK((offset.size(3) == grad_output.size(3)), "offset.size(3) = ",
                offset.size(3), "while grad_output.size(3)) = ",
                grad_output.size(3), ". They should be the same.");
  }

  TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1,
              "spatial_scale should be within (0, 1], got ", spatial_scale);

  // Check relationship between tensor.
  TORCH_CHECK((grad_output.size(0) == rois.size(0)), "grad_output.size(0) = ",
              grad_output.size(0), "while rois.size(0)) = ", rois.size(0),
              ". They should be the same.");
  TORCH_CHECK((grad_output.size(1) == input.size(1)), "grad_output.size(1) = ",
              grad_output.size(1), "while input.size(1)) = ", input.size(1),
              ". They should be the same.");
  TORCH_CHECK((grad_output.size(2) == pooled_height),
              "grad_output.size(2) = ", grad_output.size(2),
              "while pooled_height = ", pooled_height,
              ". They should be the same.");
  TORCH_CHECK((grad_output.size(3) == pooled_width),
              "grad_output.size(3) = ", grad_output.size(3),
              "while pooled_width = ", pooled_width,
              ". They should be the same.");

  // compute kernel params
  auto batch = input.size(0);
  auto channels = input.size(1);
  auto height = input.size(2);
  auto width = input.size(3);
  auto num_rois = grad_output.size(0);

  // zero element check
  TORCH_CHECK(input.size(0) != 0, "input.size(0) should not be zero, got ",
              input.size(0));
  TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ",
              rois.numel());
  if (input.numel() == 0 || grad_output.numel() == 0) {
    return;
  }

  // large tensor check
  const size_t max_input_num = 2147483648;  // 2^31, 2G num
  TORCH_CHECK(input.numel() < max_input_num,
              "input.numel() should be less than 2147483648, got ",
              input.numel());
  TORCH_CHECK(rois.numel() < max_input_num,
              "rois.numel() should be less than 2147483648, got ",
              rois.numel());
  TORCH_CHECK(grad_output.numel() < max_input_num,
              "grad_output.numel() should be less than 2147483648, got ",
              grad_output.numel());
  TORCH_CHECK(!offset.defined() || offset.numel() < max_input_num,
              "offset.numel() should be less than 2147483648, got ",
              offset.numel());

  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_output.dim());
  auto grad_output_ =
...
@@ -264,45 +69,56 @@ void DeformRoIPoolBackwardMLUKernelLauncher(
  memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
  auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
  auto rois_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format());
  auto grad_input_ =
      torch_mlu::cnnl::ops::cnnl_contiguous(grad_input, memory_format);

  // get ptr of tensors
  auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output_);
  auto grad_output_ptr = grad_output_impl->cnnlMalloc();
  auto input_impl = torch_mlu::getMluTensorImpl(input_);
  auto input_ptr = input_impl->cnnlMalloc();
  auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous);
  auto rois_ptr = rois_impl->cnnlMalloc();
  auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_);
  auto grad_input_ptr = grad_input_impl->cnnlMalloc();

  MluOpTensorDescriptor grad_output_desc, input_desc, rois_desc, offset_desc,
      grad_input_desc, grad_offset_desc;
  grad_output_desc.set_with_layout(grad_output_, MLUOP_LAYOUT_NHWC);
  input_desc.set_with_layout(input_, MLUOP_LAYOUT_NHWC);
  rois_desc.set(rois_contiguous);
  grad_input_desc.set_with_layout(grad_input_, MLUOP_LAYOUT_NHWC);

  mluOpTensorDescriptor_t offset_real_desc = NULL;
  void *offset_ptr = NULL;
  if (offset.defined() && offset.numel() > 0) {
    auto offset_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
        offset, offset.suggest_memory_format());
    offset_desc.set(offset_contiguous);
    offset_real_desc = offset_desc.desc();
    auto offset_impl = torch_mlu::getMluTensorImpl(offset_contiguous);
    offset_ptr = offset_impl->cnnlMalloc();
  }

  mluOpTensorDescriptor_t grad_offset_real_desc = NULL;
  void *grad_offset_ptr = NULL;
  if (grad_offset.defined() && grad_offset.numel() > 0) {
    auto grad_offset_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
        grad_offset, grad_offset.suggest_memory_format());
    grad_offset_desc.set(grad_offset_contiguous);
    grad_offset_real_desc = grad_offset_desc.desc();
    auto grad_offset_impl = torch_mlu::getMluTensorImpl(grad_offset_contiguous);
    grad_offset_ptr = grad_offset_impl->cnnlMalloc();
  }

  // get compute handle
  auto handle = mluOpGetCurrentHandle();
  TORCH_MLUOP_CHECK(mluOpDeformRoiPoolBackward(
      handle, grad_output_desc.desc(), grad_output_ptr, input_desc.desc(),
      input_ptr, rois_desc.desc(), rois_ptr, offset_real_desc, offset_ptr,
      pooled_height, pooled_width, spatial_scale, sampling_ratio, gamma,
      grad_input_desc.desc(), grad_input_ptr, grad_offset_real_desc,
      grad_offset_ptr));

  grad_input.copy_(grad_input_);
}
...
mmcv/ops/csrc/pytorch/mlu/diff_iou_rotated_mlu.cpp
0 → 100644
View file @ 91da9643
/*************************************************************************
* Copyright (C) 2023 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"
Tensor diff_iou_rotated_sort_vertices_forward_mlu(Tensor vertices, Tensor mask,
                                                  Tensor num_valid) {
  // params check
  TORCH_CHECK(vertices.scalar_type() == at::kFloat,
              "vertices type should be Float, got ", vertices.scalar_type());
  TORCH_CHECK(mask.scalar_type() == at::kBool, "mask should be Bool, got ",
              mask.scalar_type());
  TORCH_CHECK(num_valid.scalar_type() == at::kInt,
              "num_valid type should be Int32, got ", num_valid.scalar_type());
  TORCH_CHECK(vertices.size(2) == 24, "vertices.dim(2) should be 24, got ",
              vertices.size(2));
  TORCH_CHECK(mask.size(2) == 24, "mask.dim(2) should be 24, got ",
              mask.size(2));

  // zero-element check
  if (vertices.numel() == 0) {
    return at::empty({0}, num_valid.options().dtype(at::kInt));
  }

  auto idx = at::empty({vertices.size(0), vertices.size(1), 9},
                       num_valid.options().dtype(at::kInt));

  INITIAL_MLU_PARAM_WITH_TENSOR(vertices);
  INITIAL_MLU_PARAM_WITH_TENSOR(mask);
  INITIAL_MLU_PARAM_WITH_TENSOR(num_valid);
  INITIAL_MLU_PARAM_WITH_TENSOR(idx);

  // get compute handle
  auto handle = mluOpGetCurrentHandle();

  // launch kernel
  TORCH_MLUOP_CHECK(mluOpDiffIouRotatedSortVerticesForward(
      handle, vertices_desc.desc(), vertices_ptr, mask_desc.desc(), mask_ptr,
      num_valid_desc.desc(), num_valid_ptr, idx_desc.desc(), idx_ptr));
  return idx;
}

Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,
                                                   Tensor num_valid);

REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, MLU,
                     diff_iou_rotated_sort_vertices_forward_mlu);
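// Illustrative note (not part of the new file): REGISTER_DEVICE_IMPL from
// pytorch_device_registry.hpp binds the MLU function above to the generic
// *_impl entry point; the common op layer is assumed to resolve it roughly as
//   Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices,
//                                                      Tensor mask,
//                                                      Tensor num_valid) {
//     return DISPATCH_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl,
//                                 vertices, mask, num_valid);
//   }
// so callers stay device-agnostic and the MLU path is selected at runtime.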
mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp
View file @ 91da9643
...
@@ -12,87 +12,11 @@
#include <string>
#include <vector>

#include "mlu_common_helper.h"

void sigmoid_focal_loss_forward_mlu(Tensor input, Tensor target, Tensor weight,
                                    Tensor output, const float gamma,
                                    const float alpha) {
  // params check
  TORCH_CHECK(gamma >= 0, "gamma should be greater than or equal to 0. ",
              "But now gamma is ", gamma, ".");
...
@@ -123,103 +47,50 @@ void SigmoidFocalLossForwardMLUKernelLauncher(Tensor input, Tensor target,
    return;
  }

  // contiguous
  auto input_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      input, input.suggest_memory_format());
  // target only supports int32
  auto target_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      target.toType(at::kInt), target.suggest_memory_format());
  auto weight_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      weight, weight.suggest_memory_format());
  auto output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      output, output.suggest_memory_format());

  // set tensor descriptor
  MluOpTensorDescriptor input_desc, target_desc, weight_desc, output_desc;
  input_desc.set(input_contiguous);
  target_desc.set(target_contiguous);
  weight_desc.set(weight_contiguous);
  output_desc.set(output_contiguous);

  // get ptr of tensors
  auto input_impl = torch_mlu::getMluTensorImpl(input_contiguous);
  auto input_ptr = input_impl->cnnlMalloc();
  auto target_impl = torch_mlu::getMluTensorImpl(target_contiguous);
  auto target_ptr = target_impl->cnnlMalloc();
  auto weight_impl = torch_mlu::getMluTensorImpl(weight_contiguous);
  auto weight_ptr = weight_impl->cnnlMalloc();
  auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);
  auto output_ptr = output_impl->cnnlMalloc();

  // set preferred computation performance and reduction approach
  mluOpComputationPreference_t prefer = MLUOP_COMPUTATION_FAST;
  mluOpLossReduction_t reduction = MLUOP_LOSS_REDUCTION_NONE;
  auto handle = mluOpGetCurrentHandle();

  // launch kernel
  TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidForward(
      handle, prefer, reduction, input_desc.desc(), input_ptr,
      target_desc.desc(), target_ptr, weight_desc.desc(), weight_ptr, alpha,
      gamma, output_desc.desc(), output_ptr));
}

void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target,
                                     Tensor weight,
                                     Tensor output,
                                     const float gamma,
                                     const float alpha) {
  // params check
  TORCH_CHECK(gamma >= 0, "gamma should be greater than or equal to 0. ",
              "But now gamma is ", gamma, ".");
...
@@ -246,77 +117,51 @@ void SigmoidFocalLossBackwardMLUKernelLauncher(Tensor input, Tensor target,
    CNLOG(INFO) << "weight is a empty tensor.";
  }
  if (input.numel() == 0 || target.numel() == 0 || output.numel() == 0) {
    // return if zero-element
    return;
  }

  // contiguous
  auto input_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      input, input.suggest_memory_format());
  // target only supports int32
  auto target_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      target.toType(at::kInt), target.suggest_memory_format());
  auto weight_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      weight, weight.suggest_memory_format());
  auto output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      output, output.suggest_memory_format());

  // set tensor descriptor
  MluOpTensorDescriptor input_desc, target_desc, weight_desc, output_desc;
  input_desc.set(input_contiguous);
  target_desc.set(target_contiguous);
  weight_desc.set(weight_contiguous);
  output_desc.set(output_contiguous);

  // get ptr of tensors
  auto input_impl = torch_mlu::getMluTensorImpl(input_contiguous);
  auto input_ptr = input_impl->cnnlMalloc();
  auto target_impl = torch_mlu::getMluTensorImpl(target_contiguous);
  auto target_ptr = target_impl->cnnlMalloc();
  auto weight_impl = torch_mlu::getMluTensorImpl(weight_contiguous);
  auto weight_ptr = weight_impl->cnnlMalloc();
  auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);
  auto output_ptr = output_impl->cnnlMalloc();

  // set preferred computation performance and reduction approach
  // backward only supports MLUOP_COMPUTATION_HIGH_PRECISION
  mluOpComputationPreference_t prefer = MLUOP_COMPUTATION_HIGH_PRECISION;
  mluOpLossReduction_t reduction = MLUOP_LOSS_REDUCTION_NONE;
  auto handle = mluOpGetCurrentHandle();

  // launch kernel
  TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidBackward(
      handle, prefer, reduction, input_desc.desc(), input_ptr,
      target_desc.desc(), target_ptr, weight_desc.desc(), weight_ptr, alpha,
      gamma, output_desc.desc(), output_ptr));
}

void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
...
mmcv/ops/csrc/pytorch/mlu/iou3d_mlu.cpp
View file @ 91da9643
...
@@ -10,114 +10,31 @@
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "mlu_common_helper.h"

void IoU3DNMS3DMLUKernelLauncher(Tensor boxes, Tensor &keep, Tensor &keep_num,
                                 float iou_threshold) {
  // dimension parameters check
  TORCH_CHECK(boxes.dim() == 2, "boxes should be a 2d tensor, got ",
              boxes.dim(), "D");
  TORCH_CHECK(boxes.size(1) == 7,
              "boxes should have 7 elements in dimension 1, got ",
              boxes.size(1));

  // data type check
  TORCH_CHECK(
      boxes.scalar_type() == at::kFloat || boxes.scalar_type() == at::kHalf,
      "data type of boxes should be Float or Half, got ", boxes.scalar_type());

  if (boxes.numel() == 0) {
    return;
  }

  const size_t max_input_num = 2147483648;  // 2^31, 2G num
  TORCH_CHECK(boxes.numel() < max_input_num,
              "boxes.numel() should be less than 2147483648, got ",
              boxes.numel());

  int input_box_num = boxes.size(0);
  auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes);
  auto output = keep.to(boxes.options().dtype(at::kInt));
  auto output_size = at::empty({1}, boxes.options().dtype(at::kInt));

  MluOpTensorDescriptor boxes_desc, output_desc;
  boxes_desc.set(boxes_);
  output_desc.set(output);

  // workspace
  size_t workspace_size = 0;
  auto handle = mluOpGetCurrentHandle();
  TORCH_MLUOP_CHECK(mluOpGetNmsWorkspaceSize(handle, boxes_desc.desc(), NULL,
                                             &workspace_size));
  auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte));

  // get compute queue
  auto boxes_impl = torch_mlu::getMluTensorImpl(boxes_);
  auto boxes_ptr = boxes_impl->cnnlMalloc();
  auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
...
@@ -127,11 +44,29 @@ void IoU3DNMS3DMLUKernelLauncher(Tensor boxes, Tensor &keep, Tensor &keep_num,
  auto output_size_impl = torch_mlu::getMluTensorImpl(keep_num);
  auto output_size_ptr = output_size_impl->cnnlMalloc();

  // nms desc
  mluOpNmsDescriptor_t nms_desc;
  const mluOpNmsBoxPointMode_t box_mode = (mluOpNmsBoxPointMode_t)0;
  const mluOpNmsOutputMode_t output_mode = (mluOpNmsOutputMode_t)0;
  const mluOpNmsAlgo_t algo = (mluOpNmsAlgo_t)0;
  const mluOpNmsMethodMode_t method_mode = (mluOpNmsMethodMode_t)0;
  const float soft_nms_sigma = 0.0;
  const float confidence_threshold = 0.0;
  const int input_layout = 0;
  const bool pad_to_max_output_size = false;
  const int max_output_size = input_box_num;
  const float offset = 0.0;

  TORCH_MLUOP_CHECK(mluOpCreateNmsDescriptor(&nms_desc));
  TORCH_MLUOP_CHECK(mluOpSetNmsDescriptor(
      nms_desc, box_mode, output_mode, algo, method_mode, iou_threshold,
      soft_nms_sigma, max_output_size, confidence_threshold, offset,
      input_layout, pad_to_max_output_size));

  TORCH_MLUOP_CHECK(mluOpNms(handle, nms_desc, boxes_desc.desc(), boxes_ptr,
                             NULL, NULL, workspace_ptr, workspace_size,
                             output_desc.desc(), output_ptr, output_size_ptr));
  TORCH_MLUOP_CHECK(mluOpDestroyNmsDescriptor(nms_desc));
}

void iou3d_nms3d_forward_mlu(const Tensor boxes, Tensor &keep, Tensor &keep_num,
...
mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.cpp
0 → 100644
View file @ 91da9643
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"
// Descriptors
mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type) {
  const std::map<std::string, mluOpDataType_t> mapping_type = {
      {std::string("c10::Half"), MLUOP_DTYPE_HALF},
      {std::string("float"), MLUOP_DTYPE_FLOAT},
      {std::string("double"), MLUOP_DTYPE_DOUBLE},
      {std::string("int8"), MLUOP_DTYPE_INT8},
      {std::string("signed char"), MLUOP_DTYPE_INT8},
      {std::string("short int"), MLUOP_DTYPE_INT16},
      {std::string("short"), MLUOP_DTYPE_INT16},
      {std::string("int"), MLUOP_DTYPE_INT32},
      {std::string("long int"), MLUOP_DTYPE_INT64},
      {std::string("long"), MLUOP_DTYPE_INT64},
      {std::string("unsigned char"), MLUOP_DTYPE_UINT8},
      {std::string("bool"), MLUOP_DTYPE_BOOL},
      {std::string("c10::complex<c10::Half>"), MLUOP_DTYPE_COMPLEX_HALF},
      {std::string("c10::complex<float>"), MLUOP_DTYPE_COMPLEX_FLOAT}};

  if (mapping_type.find(std::string(data_type.name())) != mapping_type.end()) {
    return mapping_type.find(std::string(data_type.name()))->second;
  }
  return MLUOP_DTYPE_INVALID;
}

// layout
mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input) {
  auto suggest_memory_format = input.suggest_memory_format();
  mluOpTensorLayout_t layout = MLUOP_LAYOUT_ARRAY;
  switch (input.dim()) {
    case 4:
      layout = (suggest_memory_format == at::MemoryFormat::ChannelsLast)
                   ? MLUOP_LAYOUT_NHWC
                   : MLUOP_LAYOUT_NCHW;
      break;
    case 5:
      layout = (suggest_memory_format == at::MemoryFormat::ChannelsLast3d)
                   ? MLUOP_LAYOUT_NDHWC
                   : MLUOP_LAYOUT_NCDHW;
      break;
    default:
      layout = MLUOP_LAYOUT_ARRAY;
  }
  return layout;
}

mluOpReduceMode_t getMluOpReduceMode(const reduce_t reduce_type) {
  const std::map<reduce_t, mluOpReduceMode_t> mapping_type = {
      {reduce_t::MAX, MLUOP_REDUCE_DMAX},
      {reduce_t::SUM, MLUOP_REDUCE_DSUM},
      {reduce_t::MEAN, MLUOP_REDUCE_DMEAN}};
  if (mapping_type.find(reduce_type) != mapping_type.end()) {
    return mapping_type.find(reduce_type)->second;
  } else {
    TORCH_CHECK(false, "Unsupported reduce type: ", to_string(reduce_type));
    return MLUOP_REDUCE_DSUM;
  }
}

void MluOpTensorDescriptor::set(Tensor t) {
  mluOpDataType_t data_type = getMluOpDataType(t.dtype());
  mluOpTensorLayout_t layout = getMluOpSuggestLayout(t);
  int t_dim = t.dim();
  std::vector<int> dim_array;
  if (t_dim == 0) {
    // ScalarTensor (0-dim, 1-item Tensor) is viewed as size = 1 by default.
    dim_array.push_back(1);
  } else {
    for (int i = 0; i < t_dim; i++) {
      dim_array.push_back(static_cast<int>(t.sizes().vec()[i]));
    }
  }
  set_desc(t, layout, data_type, dim_array);
}

void MluOpTensorDescriptor::set_with_layout(Tensor t,
                                            mluOpTensorLayout_t layout) {
  mluOpDataType_t data_type = getMluOpDataType(t.dtype());
  int t_dim = t.dim();
  std::vector<int> shape_info = checkUpperBoundAndCastTo<int>(t.sizes().vec());
  std::vector<int> stride_info =
      checkUpperBoundAndCastTo<int>(t.strides().vec());
  if (layout == MLUOP_LAYOUT_NHWC || layout == MLUOP_LAYOUT_NDHWC ||
      layout == MLUOP_LAYOUT_NLC) {
    convertShapeAndStride(shape_info, stride_info);
  } else if (layout == MLUOP_LAYOUT_HWCN) {
    auto convertDepthWiseConvShapeStride = [](const std::vector<int64_t>& vec,
                                              std::vector<int>& target_vec,
                                              std::vector<int>& stride_vec) {
      // NCHW --> HWCN
      target_vec[0] = static_cast<int>(vec[2]);
      target_vec[1] = static_cast<int>(vec[3]);
      target_vec[2] = static_cast<int>(vec[1]);
      target_vec[3] = static_cast<int>(vec[0]);
      // Calculate strides just like a contiguous HWCN tensor.
      stride_vec[3] = 1;
      stride_vec[2] = target_vec[3] * stride_vec[3];
      stride_vec[1] = target_vec[2] * stride_vec[2];
      stride_vec[0] = target_vec[1] * stride_vec[1];
    };
    convertDepthWiseConvShapeStride(t.sizes().vec(), shape_info, stride_info);
  }
  TORCH_CHECK(mluOpSetTensorDescriptorEx(desc_, layout, data_type, t_dim,
                                         shape_info.data(),
                                         stride_info.data()) ==
                  MLUOP_STATUS_SUCCESS,
              "mluOpSetTensorDescriptorEx execution failed.");
}

void MluOpTensorDescriptor::set_desc(const at::Tensor& t,
                                     mluOpTensorLayout_t layout,
                                     mluOpDataType_t dtype,
                                     std::vector<int>& dims) {
  int dimNb = dims.size();
  TORCH_MLUOP_CHECK(
      mluOpSetTensorDescriptor(desc_, layout, dtype, dimNb, dims.data()));
}

// Handles
std::once_flag mmcv_mluop_init_flag;
std::mutex mmcv_mluop_mutex;
static std::vector<MluOpHandle> mmcv_mluop_handles;

mluOpHandle_t mluOpGetCurrentHandle(c10::DeviceIndex device_index) {
  std::call_once(mmcv_mluop_init_flag,
                 []()  // Init mmcv_mluop_handles: 1 device <-> 1 handle
                 {
                   c10::DeviceIndex num_devices = torch_mlu::device_count();
                   mmcv_mluop_handles.resize(num_devices);
                 });

  if (device_index == -1) {
    device_index = torch_mlu::current_device();
  }
  std::lock_guard<std::mutex> mmcv_mluop_guard(mmcv_mluop_mutex);
  auto queue = torch_mlu::getCurrentQueue(device_index).queue();
  mmcv_mluop_handles[device_index].setQueue(queue);
  return mmcv_mluop_handles[device_index].handle;
}
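// Usage sketch (illustrative, not part of the file): a typical op built on the
// helpers above prepares a descriptor and a per-device handle like
//   MluOpTensorDescriptor input_desc;
//   input_desc.set(input_contiguous);       // layout inferred via getMluOpSuggestLayout
//   auto handle = mluOpGetCurrentHandle();  // lazily creates one handle per device
//                                           // and rebinds it to the current queue
// and then passes input_desc.desc() plus the cnnlMalloc'ed pointer to an
// mluOp* call wrapped in TORCH_MLUOP_CHECK.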
mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h
0 → 100644
View file @ 91da9643
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#pragma once
#include <ATen/ATen.h>
#include <c10/core/ScalarType.h>
#include "aten.h"
#include "mlu_op.h"
#include "pytorch_device_registry.hpp"
#define MLUOP_MAJOR 0
#define MLUOP_MINOR 8
#define MLUOP_PATCHLEVEL 1
/*************************************************************************
* This MACRO contains operations of simple tensor to mlu-tensor.
* _contiguous, _desc, _impl, _ptr will be automatically generated in
* this MACRO.
*************************************************************************/
#define INITIAL_MLU_PARAM_WITH_TENSOR(NAME) \
auto NAME##_contigous = torch_mlu::cnnl::ops::cnnl_contiguous( \
NAME, NAME.suggest_memory_format()); \
MluOpTensorDescriptor NAME##_desc; \
NAME##_desc.set(NAME##_contigous); \
auto NAME##_impl = torch_mlu::getMluTensorImpl(NAME##_contigous); \
auto NAME##_ptr = NAME##_impl->cnnlMalloc();
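/* Illustrative expansion (not part of the header): for a tensor named `idx`,
 * INITIAL_MLU_PARAM_WITH_TENSOR(idx) generates
 *   auto idx_contigous = torch_mlu::cnnl::ops::cnnl_contiguous(
 *       idx, idx.suggest_memory_format());
 *   MluOpTensorDescriptor idx_desc;
 *   idx_desc.set(idx_contigous);
 *   auto idx_impl = torch_mlu::getMluTensorImpl(idx_contigous);
 *   auto idx_ptr = idx_impl->cnnlMalloc();
 * so idx_desc and idx_ptr can be passed straight to an mluOp* API.
 */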
#ifndef TORCH_MLUOP_CHECK
#define TORCH_MLUOP_CHECK(EXPR) \
do { \
mluOpStatus_t status = EXPR; \
if (status != MLUOP_STATUS_SUCCESS) { \
CNLOG(ERROR) << ""; \
TORCH_CHECK(false, "MLUOPS error: ", mluOpGetErrorString(status)); \
} \
} while (0);
#endif
enum class reduce_t { SUM = 0, MEAN = 1, MAX = 2 };

inline std::string to_string(reduce_t reduce_type) {
  if (reduce_type == reduce_t::MAX) {
    return "max";
  } else if (reduce_type == reduce_t::MEAN) {
    return "mean";
  } else if (reduce_type == reduce_t::SUM) {
    return "sum";
  } else {
    return "unknown reduce type";
  }
}

mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type);
mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input);
mluOpReduceMode_t getMluOpReduceMode(const reduce_t reduce_type);

class MluOpTensorDescriptor {
 public:
  MluOpTensorDescriptor() {
    TORCH_MLUOP_CHECK(mluOpCreateTensorDescriptor(&desc_));
  };
  ~MluOpTensorDescriptor() {
    TORCH_MLUOP_CHECK(mluOpDestroyTensorDescriptor(desc_));
  }

  void set(at::Tensor);
  void set_with_layout(at::Tensor, mluOpTensorLayout_t layout);
  mluOpTensorDescriptor_t desc() { return desc_; }

 private:
  mluOpTensorDescriptor_t desc_;
  void set_desc(const at::Tensor&, mluOpTensorLayout_t, mluOpDataType_t,
                std::vector<int>& dims);
};

mluOpHandle_t mluOpGetCurrentHandle(c10::DeviceIndex device_index = -1);

class MluOpHandle {
 public:
  MluOpHandle() : handle(nullptr) { TORCH_MLUOP_CHECK(mluOpCreate(&handle)); }
  ~MluOpHandle() {
    if (handle) {
      TORCH_MLUOP_CHECK(mluOpDestroy(handle));
      handle = nullptr;
    }
  }
  void setQueue(cnrtQueue_t queue) {
    TORCH_MLUOP_CHECK(mluOpSetQueue(handle, queue));
  }
  mluOpHandle_t handle;
};

// Modify tensor size and stride order when going from channels_first to
// channels_last or channels_last_3d. Note this is not the same as the
// original PyTorch layout; the resulting layout follows the real data
// storage order.
// example: modify a channels_last tensor dim to an NHWC tensor desc.
//   N     C  H  W  -->  N     H  W  C
//   C*H*W 1  W  C  -->  C*H*W  W  C  1
template <typename T>
void convertShapeAndStride(std::vector<T>& shape_info,
                           std::vector<T>& stride_info) {
  TORCH_MLU_CHECK(shape_info.size() == stride_info.size(),
                  "shape size need equal to stride size.");
  const int dim = shape_info.size();
  std::vector<T> temp_shape_info(dim);
  std::vector<T> temp_stride_info(dim);
  temp_shape_info[0] = shape_info[0];
  temp_stride_info[0] = stride_info[0];
  for (size_t i = 0; i < dim - 1; ++i) {
    const int index = (i + 1) % (dim - 1) + 1;
    temp_shape_info[i + 1] = shape_info[index];
    temp_stride_info[i + 1] = stride_info[index];
  }
  shape_info.assign(temp_shape_info.begin(), temp_shape_info.end());
  stride_info.assign(temp_stride_info.begin(), temp_stride_info.end());
}

// Torch tensors provide int64_t shapes and strides, but the mluOps descriptor
// requires int32. Use this function to cast safely, or report an error.
template <typename DST_T, typename SRC_T>
std::vector<DST_T> checkUpperBoundAndCastTo(const std::vector<SRC_T>& input) {
  std::vector<DST_T> output;
  output.reserve(input.size());
  for (const auto& val : input) {
    if (val > std::numeric_limits<DST_T>::max()) {
      TORCH_MLU_CHECK(false, "Requires dim size not greater than ",
                      std::numeric_limits<DST_T>::max(), ". But got ", val,
                      ".");
    }
    output.push_back(static_cast<DST_T>(val));
  }
  return output;
}
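// Worked example (illustrative, not part of the header): for a channels_last
// tensor with N=2, C=3, H=4, W=5, i.e. sizes {2, 3, 4, 5} and strides
// {60, 1, 15, 3}, convertShapeAndStride rotates dims 1..3 so the descriptor
// receives sizes {2, 4, 5, 3} and strides {60, 15, 3, 1}, matching the NHWC
// storage order expected by MLUOP_LAYOUT_NHWC.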
mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp
View file @ 91da9643
...
@@ -9,396 +9,104 @@
...
@@ -9,396 +9,104 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
*************************************************************************/
#include "mlu_common_helper.h"
#include "pytorch_device_registry.hpp"
#include "pytorch_device_registry.hpp"
#include "pytorch_mlu_helper.hpp"
#include "pytorch_mlu_helper.hpp"
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
Tensor
MsDeformAttnForwardLauncher
(
const
Tensor
&
value
,
const
Tensor
&
spatial_shapes
,
void
KernelMsDeformAttnForward
(
const
Tensor
&
level_start_index
,
cnrtDim3_t
k_dim
,
cnrtFunctionType_t
k_type
,
cnrtQueue_t
queue
,
const
Tensor
&
sampling_loc
,
const
cnrtDataType_t
d_type
,
const
char
*
data_value_gdram
,
const
Tensor
&
attn_weight
,
const
char
*
data_spatial_shapes_gdram
,
const
int
im2col_step
)
{
const
char
*
data_level_start_index_gdram
,
auto
handle
=
mluOpGetCurrentHandle
();
const
char
*
data_sampling_loc_gdram
,
const
char
*
data_attn_weight_gdram
,
const
int32_t
batch_size
,
const
int32_t
num_keys
,
const
int32_t
num_heads
,
const
int32_t
channels
,
const
int32_t
num_levels
,
const
int32_t
num_queries
,
const
int32_t
num_points
,
char
*
data_col_gdram
);
void
KernelMsDeformAttnBackward
(
cnrtDim3_t
k_dim
,
cnrtFunctionType_t
k_type
,
cnrtQueue_t
queue
,
const
cnrtDataType_t
d_type
,
const
float
*
data_value
,
const
int32_t
*
spatial_shapes
,
const
int32_t
*
data_level_start_index
,
const
float
*
data_sampling_loc
,
const
float
*
data_attn_weight
,
const
float
*
grad_output
,
const
int32_t
batch_size
,
const
int32_t
num_keys
,
const
int32_t
num_heads
,
const
int32_t
channels
,
const
int32_t
num_levels
,
const
int32_t
num_queries
,
const
int32_t
num_points
,
float
*
grad_value
,
float
*
grad_sampling_loc
,
float
*
grad_attn_weight
);
// policy function
static
void
policyFuncForward
(
cnrtDim3_t
*
k_dim
,
cnrtFunctionType_t
*
k_type
,
const
int
batch_size
,
const
int
num_queries
,
const
int
num_heads
)
{
k_dim
->
x
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrMcorePerCluster
);
k_dim
->
y
=
MIN
((
batch_size
*
num_queries
*
num_heads
+
k_dim
->
x
-
1
)
/
k_dim
->
x
,
torch_mlu
::
getDeviceAttr
(
cnrtAttrClusterCount
));
k_dim
->
z
=
1
;
#if __BANG_ARCH__ == 520
*
k_type
=
CNRT_FUNC_TYPE_BLOCK
;
#else
*
k_type
=
CNRT_FUNC_TYPE_UNION1
;
#endif
}
// policy function for backward
static
void
policyFuncBackward
(
const
int32_t
batch_size
,
const
int32_t
num_queries
,
const
int32_t
num_heads
,
const
int32_t
num_levels
,
cnrtFunctionType_t
*
k_type
,
cnrtDim3_t
*
k_dim
)
{
size_t
cluster_limit
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrClusterCount
);
size_t
core_limit
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrMcorePerCluster
);
k_dim
->
x
=
core_limit
;
int32_t
total_num
=
batch_size
*
num_queries
*
num_heads
*
num_levels
;
size_t
total_num_align
=
CEIL_ALIGN
(
total_num
,
core_limit
);
k_dim
->
y
=
(
total_num_align
/
core_limit
)
>
cluster_limit
?
cluster_limit
:
(
total_num_align
/
core_limit
);
k_dim
->
z
=
1
;
*
k_type
=
CNRT_FUNC_TYPE_UNION1
;
}
Tensor
ms_deform_attn_mlu_forward
(
const
Tensor
&
value
,
const
Tensor
&
spatial_shapes
,
const
Tensor
&
level_start_index
,
const
Tensor
&
sampling_loc
,
const
Tensor
&
attn_weight
,
const
int
im2col_step
)
{
// check contiguous
AT_ASSERTM
(
value
.
is_contiguous
(),
"value tensor has to be contiguous"
);
AT_ASSERTM
(
spatial_shapes
.
is_contiguous
(),
"spatial_shapes tensor has to be contiguous"
);
AT_ASSERTM
(
level_start_index
.
is_contiguous
(),
"level_start_index tensor has to be contiguous"
);
AT_ASSERTM
(
sampling_loc
.
is_contiguous
(),
"sampling_loc tensor has to be contiguous"
);
AT_ASSERTM
(
attn_weight
.
is_contiguous
(),
"attn_weight tensor has to be contiguous"
);
// check datatype
TORCH_CHECK
((
value
.
scalar_type
()
==
at
::
kFloat
),
"value type should be Float, got "
,
value
.
scalar_type
(),
"."
);
TORCH_CHECK
((
spatial_shapes
.
scalar_type
()
==
at
::
kInt
||
spatial_shapes
.
scalar_type
()
==
at
::
kLong
),
"spatial_shapes type should be Int, got "
,
spatial_shapes
.
scalar_type
(),
"."
);
TORCH_CHECK
((
level_start_index
.
scalar_type
()
==
at
::
kInt
||
level_start_index
.
scalar_type
()
==
at
::
kLong
),
"level_start_index type should be Int, got "
,
level_start_index
.
scalar_type
(),
"."
);
TORCH_CHECK
((
sampling_loc
.
scalar_type
()
==
at
::
kFloat
),
"sampling_loc type should be Float, got "
,
sampling_loc
.
scalar_type
(),
"."
);
TORCH_CHECK
((
attn_weight
.
scalar_type
()
==
at
::
kFloat
),
"attn_weight type should be Float, got "
,
attn_weight
.
scalar_type
(),
"."
);
// check shape
TORCH_CHECK
(
value
.
dim
()
==
4
,
"value should be a 4d tensor, got "
,
value
.
dim
(),
"D."
);
TORCH_CHECK
(
spatial_shapes
.
dim
()
==
2
,
"spatial_shapes should be a 2d tensor, got "
,
spatial_shapes
.
dim
(),
"D."
);
TORCH_CHECK
(
level_start_index
.
dim
()
==
1
,
"level_start_index should be a 1d tensor, got "
,
level_start_index
.
dim
(),
"D."
);
TORCH_CHECK
(
sampling_loc
.
dim
()
==
6
,
"sampling_loc should be a 6d tensor, got "
,
sampling_loc
.
dim
(),
"D."
);
TORCH_CHECK
(
attn_weight
.
dim
()
==
5
,
"attn_weight should be a 5d tensor, got "
,
attn_weight
.
dim
(),
"D."
);
const
int
batch_size
=
value
.
size
(
0
);
const
int
batch_size
=
value
.
size
(
0
);
const
int
num_keys
=
value
.
size
(
1
);
const
int
num_heads
=
value
.
size
(
2
);
const
int
num_heads
=
value
.
size
(
2
);
const
int
channels
=
value
.
size
(
3
);
const
int
channels
=
value
.
size
(
3
);
const
int
num_levels
=
spatial_shapes
.
size
(
0
);
const
int
num_queries
=
sampling_loc
.
size
(
1
);
const
int
num_queries
=
sampling_loc
.
size
(
1
);
const
int
num_points
=
sampling_loc
.
size
(
4
);
TORCH_CHECK
(
spatial_shapes
.
size
(
1
)
==
2
,
"the 2nd dimensions of spatial_shapes should be 2, got "
,
spatial_shapes
.
size
(
1
),
"."
);
TORCH_CHECK
(
sampling_loc
.
size
(
5
)
==
2
,
"the 6th dimensions of sampling_loc should be 2, got "
,
sampling_loc
.
size
(
5
),
"."
);
TORCH_CHECK
((
sampling_loc
.
size
(
0
)
==
batch_size
),
"the 1st dimensions of sampling_loc should be batch_size, "
,
"but now the 1st dimension of sampling_loc is "
,
sampling_loc
.
size
(
0
),
", and batch_size is "
,
batch_size
,
"."
);
TORCH_CHECK
((
attn_weight
.
size
(
0
)
==
batch_size
),
"the 1st dimensions of attn_weight should be batch_size, "
,
"but now the 1st dimension of attn_weight is "
,
attn_weight
.
size
(
0
),
", and batch_size is "
,
batch_size
,
"."
);
TORCH_CHECK
((
sampling_loc
.
size
(
2
)
==
num_heads
),
"the 3rd dimensions of sampling_loc should be num_heads, "
,
"but now the 3rd dimension of sampling_loc is "
,
sampling_loc
.
size
(
2
),
", and num_heads is "
,
num_heads
,
"."
);
TORCH_CHECK
((
attn_weight
.
size
(
2
)
==
num_heads
),
"the 3rd dimensions of attn_weight should be num_heads, "
,
"but now the 3rd dimension of attn_weight is "
,
attn_weight
.
size
(
2
),
", and num_heads is "
,
num_heads
,
"."
);
TORCH_CHECK
((
level_start_index
.
size
(
0
)
==
num_levels
),
"the 1st dimensions of level_start_index should be num_levels, "
,
"but now the 1st dimension of level_start_index is "
,
level_start_index
.
size
(
0
),
", and num_levels is "
,
num_levels
,
"."
);
TORCH_CHECK
((
sampling_loc
.
size
(
3
)
==
num_levels
),
"the 4th dimensions of sampling_loc should be num_levels, "
,
"but now the 4th dimension of sampling_loc is "
,
sampling_loc
.
size
(
3
),
", and num_levels is "
,
num_levels
,
"."
);
TORCH_CHECK
((
attn_weight
.
size
(
3
)
==
num_levels
),
"the 4th dimensions of attn_weight should be num_levels, "
,
"but now the 4th dimension of attn_weight is "
,
attn_weight
.
size
(
3
),
", and num_levels is "
,
num_levels
,
"."
);
TORCH_CHECK
((
attn_weight
.
size
(
1
)
==
num_queries
),
"the 2nd dimensions of attn_weight should be num_queries, "
,
"but now the 2nd dimension of attn_weight is "
,
attn_weight
.
size
(
1
),
", and num_queries is "
,
num_queries
,
"."
);
TORCH_CHECK
((
attn_weight
.
size
(
4
)
==
num_points
),
"the 5th dimensions of attn_weight should be num_points, "
,
"but now the 5th dimension of attn_weight is "
,
attn_weight
.
size
(
4
),
", and num_points is "
,
num_points
,
"."
);
auto
output
=
at
::
zeros
({
batch_size
,
num_queries
,
num_heads
,
channels
},
auto
output
=
at
::
zeros
({
batch_size
,
num_queries
,
num_heads
,
channels
},
value
.
options
());
value
.
options
());
auto
spatial_shapes_int
=
spatial_shapes
.
to
(
at
::
kInt
);
// large tensor check
auto
level_start_index_int
=
level_start_index
.
to
(
at
::
kInt
);
const
size_t
max_input_size
=
2147483648
;
INITIAL_MLU_PARAM_WITH_TENSOR
(
output
);
TORCH_CHECK
(
value
.
numel
()
<
max_input_size
,
INITIAL_MLU_PARAM_WITH_TENSOR
(
value
);
"value element num should be less than 2^31, got "
,
value
.
numel
(),
INITIAL_MLU_PARAM_WITH_TENSOR
(
spatial_shapes_int
);
"."
);
INITIAL_MLU_PARAM_WITH_TENSOR
(
level_start_index_int
);
TORCH_CHECK
(
sampling_loc
.
numel
()
<
max_input_size
,
INITIAL_MLU_PARAM_WITH_TENSOR
(
sampling_loc
);
"sampling_loc element num should be less than 2^31, got "
,
INITIAL_MLU_PARAM_WITH_TENSOR
(
attn_weight
);
sampling_loc
.
numel
(),
"."
);
TORCH_CHECK
(
output
.
numel
()
<
max_input_size
,
TORCH_MLUOP_CHECK
(
mluOpMsDeformAttnForward
(
"output element num should be less than 2^31, got "
,
handle
,
value_desc
.
desc
(),
value_ptr
,
spatial_shapes_int_desc
.
desc
(),
output
.
numel
(),
"."
);
spatial_shapes_int_ptr
,
level_start_index_int_desc
.
desc
(),
level_start_index_int_ptr
,
sampling_loc_desc
.
desc
(),
sampling_loc_ptr
,
// check zero element
attn_weight_desc
.
desc
(),
attn_weight_ptr
,
im2col_step
,
output_desc
.
desc
(),
TORCH_CHECK
(
batch_size
!=
0
,
"batch_size should not be zero"
);
output_ptr
));
TORCH_CHECK
(
num_heads
!=
0
,
"num_heads should not be zero"
);
TORCH_CHECK
(
channels
!=
0
,
"channels should not be zero"
);
TORCH_CHECK
(
num_queries
!=
0
,
"num_queries should not be zero"
);
if
(
num_keys
==
0
||
num_levels
==
0
||
num_points
==
0
)
{
return
output
;
}
// calculate task dimension
cnrtDim3_t
k_dim
;
cnrtFunctionType_t
k_type
;
policyFuncForward
(
&
k_dim
,
&
k_type
,
batch_size
,
num_queries
,
num_heads
);
// get compute queue
auto
queue
=
torch_mlu
::
getCurQueue
();
auto
spatial_shapes_
=
spatial_shapes
.
to
(
at
::
kInt
);
auto
level_start_index_
=
level_start_index
.
to
(
at
::
kInt
);
// get ptr of tensors
auto
value_impl
=
torch_mlu
::
getMluTensorImpl
(
value
);
auto
value_ptr
=
value_impl
->
cnnlMalloc
();
auto
spatial_shapes_impl
=
torch_mlu
::
getMluTensorImpl
(
spatial_shapes_
);
auto
spatial_shapes_ptr
=
spatial_shapes_impl
->
cnnlMalloc
();
auto
level_start_index_impl
=
torch_mlu
::
getMluTensorImpl
(
level_start_index_
);
auto
level_start_index_ptr
=
level_start_index_impl
->
cnnlMalloc
();
auto
sampling_loc_impl
=
torch_mlu
::
getMluTensorImpl
(
sampling_loc
);
auto
sampling_loc_ptr
=
sampling_loc_impl
->
cnnlMalloc
();
auto
attn_weight_impl
=
torch_mlu
::
getMluTensorImpl
(
attn_weight
);
auto
attn_weight_ptr
=
attn_weight_impl
->
cnnlMalloc
();
auto
output_impl
=
torch_mlu
::
getMluTensorImpl
(
output
);
auto
output_ptr
=
output_impl
->
cnnlMalloc
();
// get compute dtype of input
cnrtDataType_t
data_type
=
torch_mlu
::
toCnrtDtype
(
value
.
dtype
());
// launch kernel
CNLOG
(
INFO
)
<<
"Launch Kernel MLUKernelMsDeformAttnForward<<<"
<<
k_dim
.
x
<<
", "
<<
k_dim
.
y
<<
", "
<<
k_dim
.
z
<<
">>>"
;
KernelMsDeformAttnForward
(
k_dim
,
k_type
,
queue
,
data_type
,
(
char
*
)
value_ptr
,
(
char
*
)
spatial_shapes_ptr
,
(
char
*
)
level_start_index_ptr
,
(
char
*
)
sampling_loc_ptr
,
(
char
*
)
attn_weight_ptr
,
batch_size
,
num_keys
,
num_heads
,
channels
,
num_levels
,
num_queries
,
num_points
,
(
char
*
)
output_ptr
);
output
=
output
.
view
({
batch_size
,
num_queries
,
num_heads
*
channels
});
output
=
output
.
view
({
batch_size
,
num_queries
,
num_heads
*
channels
});
return
output
;
return
output
;
}
}
void ms_deform_attn_mlu_backward(
    const Tensor& value, const Tensor& spatial_shapes,
    const Tensor& level_start_index, const Tensor& sampling_loc,
    const Tensor& attn_weight, const Tensor& grad_output, Tensor& grad_value,
    Tensor& grad_sampling_loc, Tensor& grad_attn_weight,
    const int im2col_step) {
  // Input validation carried over from the launcher-based implementation
  // that this commit replaces with mluOpMsDeformAttnBackward.
  // check contiguous
  AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
  AT_ASSERTM(spatial_shapes.is_contiguous(),
             "spatial_shapes tensor has to be contiguous");
  AT_ASSERTM(level_start_index.is_contiguous(),
             "level_start_index tensor has to be contiguous");
  AT_ASSERTM(sampling_loc.is_contiguous(),
             "sampling_loc tensor has to be contiguous");
  AT_ASSERTM(attn_weight.is_contiguous(),
             "attn_weight tensor has to be contiguous");
  AT_ASSERTM(grad_output.is_contiguous(),
             "grad_output tensor has to be contiguous");
  // check datatype
  TORCH_CHECK((value.scalar_type() == at::kFloat),
              "value type should be Float, got ", value.scalar_type(), ".");
  TORCH_CHECK((spatial_shapes.scalar_type() == at::kInt ||
               spatial_shapes.scalar_type() == at::kLong),
              "spatial_shapes type should be Int, got ",
              spatial_shapes.scalar_type(), ".");
  TORCH_CHECK((level_start_index.scalar_type() == at::kInt ||
               level_start_index.scalar_type() == at::kLong),
              "level_start_index type should be Int, got ",
              level_start_index.scalar_type(), ".");
  TORCH_CHECK((sampling_loc.scalar_type() == at::kFloat),
              "sampling_loc type should be Float, got ",
              sampling_loc.scalar_type(), ".");
  TORCH_CHECK((attn_weight.scalar_type() == at::kFloat),
              "attn_weight type should be Float, got ",
              attn_weight.scalar_type(), ".");
  TORCH_CHECK((grad_output.scalar_type() == at::kFloat),
              "grad_output type should be Float, got ",
              grad_output.scalar_type(), ".");
  const int batch_size = value.size(0);
  const int num_keys = value.size(1);
  const int num_heads = value.size(2);
  const int channels = value.size(3);
  const int num_levels = spatial_shapes.size(0);
  const int num_queries = sampling_loc.size(1);
  const int num_points = sampling_loc.size(4);
  // Check shape.
  TORCH_CHECK(spatial_shapes.size(1) == 2,
              "the 2nd dimensions of spatial_shapes should be 2, got ",
              spatial_shapes.size(1), ".");
  TORCH_CHECK((level_start_index.size(0) == num_levels),
              "the 1st dimensions of level_start_index should be num_levels, ",
              "but now the 1st dimension of level_start_index is ",
              level_start_index.size(0), ", and num_levels is ", num_levels,
              ".");
  TORCH_CHECK((sampling_loc.size(0) == batch_size),
              "the 1st dimensions of sampling_loc should be batch_size, ",
              "but now the 1st dimension of sampling_loc is ",
              sampling_loc.size(0), ", and batch_size is ", batch_size, ".");
  TORCH_CHECK((sampling_loc.size(2) == num_heads),
              "the 3rd dimensions of sampling_loc should be num_heads, ",
              "but now the 3rd dimension of sampling_loc is ",
              sampling_loc.size(2), ", and num_heads is ", num_heads, ".");
  TORCH_CHECK((sampling_loc.size(3) == num_levels),
              "the 4th dimensions of sampling_loc should be num_levels, ",
              "but now the 4th dimension of sampling_loc is ",
              sampling_loc.size(3), ", and num_levels is ", num_levels, ".");
  TORCH_CHECK(sampling_loc.size(5) == 2,
              "the 6th dimensions of sampling_loc should be 2, got ",
              sampling_loc.size(5), ".");
  TORCH_CHECK((attn_weight.size(0) == batch_size),
              "the 1st dimensions of attn_weight should be batch_size, ",
              "but now the 1st dimension of attn_weight is ",
              attn_weight.size(0), ", and batch_size is ", batch_size, ".");
  TORCH_CHECK((attn_weight.size(1) == num_queries),
              "the 2nd dimensions of attn_weight should be num_queries, ",
              "but now the 2nd dimension of attn_weight is ",
              attn_weight.size(1), ", and num_queries is ", num_queries, ".");
  TORCH_CHECK((attn_weight.size(2) == num_heads),
              "the 3rd dimensions of attn_weight should be num_heads, ",
              "but now the 3rd dimension of attn_weight is ",
              attn_weight.size(2), ", and num_heads is ", num_heads, ".");
  TORCH_CHECK((attn_weight.size(3) == num_levels),
              "the 4th dimensions of attn_weight should be num_levels, ",
              "but now the 4th dimension of attn_weight is ",
              attn_weight.size(3), ", and num_levels is ", num_levels, ".");
  TORCH_CHECK((attn_weight.size(4) == num_points),
              "the 5th dimensions of attn_weight should be num_points, ",
              "but now the 5th dimension of attn_weight is ",
              attn_weight.size(4), ", and num_points is ", num_points, ".");
  TORCH_CHECK((grad_output.size(0) == batch_size),
              "the 1st dimensions of grad_output should be batch_size, ",
              "but now the 1st dimension of grad_output is ",
              grad_output.size(0), ", and batch_size is ", batch_size, ".");
  TORCH_CHECK((grad_output.size(1) == num_queries),
              "the 2nd dimensions of grad_output should be num_queries, ",
              "but now the 2nd dimension of grad_output is ",
              grad_output.size(1), ", and num_queries is ", num_queries, ".");
  TORCH_CHECK(
      (grad_output.size(2) == num_heads * channels),
      "the 3rd dimensions of grad_output should be num_heads * channels, ",
      "but now the 3rd dimension of grad_output is ", grad_output.size(2),
      ", and num_heads * channels is ", num_heads * channels, ".");
  // check zero element
  TORCH_CHECK(batch_size != 0, "The batch_size is zero.");
  TORCH_CHECK(channels != 0, "The channels is zero.");
  TORCH_CHECK(num_keys != 0, "The num_keys is zero.");
  TORCH_CHECK(num_heads != 0, "The num_heads is zero.");
  TORCH_CHECK(num_queries != 0, "The num_queries is zero.");
  if (num_levels == 0 || num_points == 0) {
    return;
  }
  auto handle = mluOpGetCurrentHandle();
  auto spatial_shapes_int = spatial_shapes.to(at::kInt);
  auto level_start_index_int = level_start_index.to(at::kInt);
  auto grad_output_dim4 =
      grad_output.view({batch_size, num_queries, num_heads, channels});
  INITIAL_MLU_PARAM_WITH_TENSOR(value);
  INITIAL_MLU_PARAM_WITH_TENSOR(spatial_shapes_int);
  INITIAL_MLU_PARAM_WITH_TENSOR(level_start_index_int);
  INITIAL_MLU_PARAM_WITH_TENSOR(sampling_loc);
  INITIAL_MLU_PARAM_WITH_TENSOR(attn_weight);
  INITIAL_MLU_PARAM_WITH_TENSOR(grad_output_dim4);
  INITIAL_MLU_PARAM_WITH_TENSOR(grad_value);
  INITIAL_MLU_PARAM_WITH_TENSOR(grad_sampling_loc);
  INITIAL_MLU_PARAM_WITH_TENSOR(grad_attn_weight);
  mluOpMsDeformAttnBackward(
      handle, value_desc.desc(), value_ptr, spatial_shapes_int_desc.desc(),
      spatial_shapes_int_ptr, level_start_index_int_desc.desc(),
      level_start_index_int_ptr, sampling_loc_desc.desc(), sampling_loc_ptr,
      attn_weight_desc.desc(), attn_weight_ptr, grad_output_dim4_desc.desc(),
      grad_output_dim4_ptr, im2col_step, grad_value_desc.desc(),
      grad_value_ptr, grad_sampling_loc_desc.desc(), grad_sampling_loc_ptr,
      grad_attn_weight_desc.desc(), grad_attn_weight_ptr);
  return;
}

// The replaced MsDeformAttnForwardLauncher / MsDeformAttnBackwardLauncher
// pair additionally computed the launch dimension with
// policyFuncBackward(batch_size, num_queries, num_heads, num_levels,
// &k_type, &k_dim), fetched the compute queue and the raw MLU pointers of
// every tensor via torch_mlu::getMluTensorImpl(...)->cnnlMalloc(), resolved
// the compute dtype with torch_mlu::toCnrtDtype(value.dtype()), logged
// "Launch Kernel MLUKernelMsDeformAttnBackward<<<...>>>" with CNLOG(INFO),
// and launched KernelMsDeformAttnBackward with float* / int32_t* pointers
// plus batch_size, num_keys, num_heads, channels, num_levels, num_queries
// and num_points; thin wrappers named ms_deform_attn_mlu_forward /
// ms_deform_attn_mlu_backward simply forwarded their arguments to those
// launchers.

Tensor ms_deform_attn_impl_forward(const Tensor& value,
...
@@ -416,5 +124,6 @@ void ms_deform_attn_impl_backward(
REGISTER_DEVICE_IMPL(ms_deform_attn_impl_forward, MLU,
                     ms_deform_attn_mlu_forward);
REGISTER_DEVICE_IMPL(ms_deform_attn_impl_backward, MLU,
                     ms_deform_attn_mlu_backward);
mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp
View file @
91da9643
...
@@ -10,123 +10,35 @@
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"

// Dropped in this commit: the includes of "pytorch_device_registry.hpp" and
// "pytorch_mlu_helper.hpp", the extern declaration of KernelNms(k_dim,
// k_type, queue, data_type_input, boxes_ptr, scores_ptr, input_num_boxes,
// max_output_boxes, iou_threshold, offset, workspace_ptr, output_size_ptr,
// output_ptr), the selectUnionType() helper that doubled box_num_per_core
// until at least 256 boxes per core were reached (to keep IO bandwidth up),
// and the policyFunc() that picked CNRT_FUNC_TYPE_BLOCK / UNION1 / UNIONx
// from getJobLimitCapability() and the per-cluster core count.

Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold,
                            int offset) {
  if (boxes.numel() == 0) {
    return at::empty({0}, boxes.options().dtype(at::kLong));
  }
  int max_output_boxes = boxes.size(0);

  // transpose boxes (n, 4) to (4, n) for better performance
  auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes);
  auto scores_ = torch_mlu::cnnl::ops::cnnl_contiguous(scores);
  auto output = at::empty({max_output_boxes},
                          boxes.options().dtype(at::kInt));
  auto output_size = at::empty({1}, scores.options().dtype(at::kInt));

  MluOpTensorDescriptor boxes_desc, scores_desc, output_desc;
  boxes_desc.set(boxes_);
  scores_desc.set(scores_);
  output_desc.set(output);

  // workspace
  size_t workspace_size = 0;
  auto handle = mluOpGetCurrentHandle();
  TORCH_MLUOP_CHECK(mluOpGetNmsWorkspaceSize(
      handle, boxes_desc.desc(), scores_desc.desc(), &workspace_size));
  auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte));

  // get compute queue
  auto boxes_impl = torch_mlu::getMluTensorImpl(boxes_);
  auto boxes_ptr = boxes_impl->cnnlMalloc();
  auto scores_impl = torch_mlu::getMluTensorImpl(scores_);
...
@@ -138,14 +50,32 @@ Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold,
  auto output_size_impl = torch_mlu::getMluTensorImpl(output_size);
  auto output_size_ptr = output_size_impl->cnnlMalloc();

  // nms desc
  mluOpNmsDescriptor_t nms_desc;
  const mluOpNmsBoxPointMode_t box_mode = (mluOpNmsBoxPointMode_t)0;
  const mluOpNmsOutputMode_t output_mode = (mluOpNmsOutputMode_t)0;
  const mluOpNmsAlgo_t algo = (mluOpNmsAlgo_t)0;
  const mluOpNmsMethodMode_t method_mode = (mluOpNmsMethodMode_t)0;
  const float soft_nms_sigma = 0.0;
  const float confidence_threshold = 0.0;
  const int input_layout = 0;
  const bool pad_to_max_output_size = false;
  const int max_output_size = max_output_boxes;

  TORCH_MLUOP_CHECK(mluOpCreateNmsDescriptor(&nms_desc));
  TORCH_MLUOP_CHECK(mluOpSetNmsDescriptor(
      nms_desc, box_mode, output_mode, algo, method_mode, iou_threshold,
      soft_nms_sigma, max_output_size, confidence_threshold, (float)offset,
      input_layout, pad_to_max_output_size));

  TORCH_MLUOP_CHECK(mluOpNms(handle, nms_desc, boxes_desc.desc(), boxes_ptr,
                             scores_desc.desc(), scores_ptr, workspace_ptr,
                             workspace_size, output_desc.desc(), output_ptr,
                             output_size_ptr));
  TORCH_MLUOP_CHECK(mluOpDestroyNmsDescriptor(nms_desc));

  int output_num = *static_cast<int *>(output_size.cpu().data_ptr());
  auto ret = output.to(boxes.options().dtype(at::kLong));
  return ret.slice(0, 0, output_num);
}

// The replaced launcher additionally checked that boxes is a 2-D (n, 4)
// Float/Half tensor and scores a 1-D tensor of the same dtype, transposed
// boxes to (4, n) and made the transposed view contiguous, sized a by-hand
// workspace (x1, x2, y1, y2 and score per box, plus per-cluster scratch when
// __BANG_ARCH__ > 370), computed the launch dimension with policyFunc(), and
// called KernelNms(...) on the current queue before slicing the kLong output
// to output_num entries.

Tensor nms_mlu(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
...
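The new NMS launcher above configures the operator through an explicit descriptor lifecycle: create, set, use in mluOpNms, destroy. As a side note on that pattern, here is a hedged, self-contained sketch of how such a lifecycle can be wrapped in an RAII guard so the handle is released even if an intermediate check throws. The fake* functions and FakeDesc type are placeholders invented for the example so it compiles on its own; they are not mluOp APIs, and MMCV itself releases the descriptor manually as shown above.

// Sketch only: RAII around a create/set/use/destroy descriptor lifecycle.
#include <iostream>
#include <stdexcept>

struct FakeDesc { float iou = 0.f; };
static int fakeCreate(FakeDesc** d) { *d = new FakeDesc(); return 0; }
static int fakeSet(FakeDesc* d, float iou) { d->iou = iou; return 0; }
static int fakeDestroy(FakeDesc* d) { delete d; return 0; }

class NmsDescGuard {
 public:
  explicit NmsDescGuard(float iou_threshold) {
    if (fakeCreate(&desc_) != 0) throw std::runtime_error("create failed");
    if (fakeSet(desc_, iou_threshold) != 0) {
      fakeDestroy(desc_);
      throw std::runtime_error("set failed");
    }
  }
  ~NmsDescGuard() { fakeDestroy(desc_); }  // always released, even on throw
  NmsDescGuard(const NmsDescGuard&) = delete;
  NmsDescGuard& operator=(const NmsDescGuard&) = delete;
  FakeDesc* get() const { return desc_; }

 private:
  FakeDesc* desc_ = nullptr;
};

int main() {
  NmsDescGuard guard(0.5f);
  std::cout << "descriptor ready, iou=" << guard.get()->iou << "\n";
}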
mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp
0 → 100644
View file @
91da9643
/*************************************************************************
* Copyright (C) 2021 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"
Tensor nms_rotated_mlu(Tensor boxes, Tensor scores, float iou_threshold) {
  if (boxes.numel() == 0) {
    return at::empty({0}, boxes.options().dtype(at::kLong));
  }
  int boxes_num = boxes.size(0);
  auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes);
  auto scores_ = torch_mlu::cnnl::ops::cnnl_contiguous(scores);
  auto output = at::empty({boxes_num}, boxes.options().dtype(at::kInt));
  auto output_size = at::empty({1}, scores.options().dtype(at::kInt));

  MluOpTensorDescriptor boxes_desc, scores_desc, output_desc;
  boxes_desc.set(boxes_);
  scores_desc.set(scores_);
  output_desc.set(output);

  // workspace
  size_t workspace_size = 0;
  auto handle = mluOpGetCurrentHandle();
  TORCH_MLUOP_CHECK(mluOpGetNmsRotatedWorkspaceSize(handle, boxes_desc.desc(),
                                                    &workspace_size));
  auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte));

  auto boxes_impl = torch_mlu::getMluTensorImpl(boxes_);
  auto boxes_ptr = boxes_impl->cnnlMalloc();
  auto scores_impl = torch_mlu::getMluTensorImpl(scores_);
  auto scores_ptr = scores_impl->cnnlMalloc();
  auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
  auto workspace_ptr = workspace_impl->cnnlMalloc();
  auto output_impl = torch_mlu::getMluTensorImpl(output);
  auto output_ptr = output_impl->cnnlMalloc();
  auto output_size_impl = torch_mlu::getMluTensorImpl(output_size);
  auto output_size_ptr = output_size_impl->cnnlMalloc();

  TORCH_MLUOP_CHECK(mluOpNmsRotated(handle, iou_threshold, boxes_desc.desc(),
                                    boxes_ptr, scores_desc.desc(), scores_ptr,
                                    workspace_ptr, workspace_size,
                                    output_desc.desc(), output_ptr,
                                    (int *)output_size_ptr));
  int output_num = *static_cast<int *>(output_size.cpu().data_ptr());
  auto ret = output.to(boxes.options().dtype(at::kLong));
  return ret.slice(0, 0, output_num);
}
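nms_rotated_mlu above follows a common device-op convention: allocate an output at full capacity (boxes_num entries), let the kernel report how many entries are valid through a small count tensor, and slice on the host. The sketch below illustrates that convention in plain C++ with std::vector standing in for the MLU tensors; fake_nms and its 0.5 score cut-off are invented for the example and have nothing to do with the real rotated-IoU kernel.

// Sketch only: fixed-capacity output, then truncate by a returned count.
#include <iostream>
#include <vector>

// Pretend device op: writes kept indices into `out` (capacity = #boxes)
// and reports how many entries are valid via `out_count`.
static void fake_nms(const std::vector<float>& scores, std::vector<long>& out,
                     int* out_count) {
  int n = 0;
  for (size_t i = 0; i < scores.size(); ++i)
    if (scores[i] > 0.5f) out[n++] = static_cast<long>(i);
  *out_count = n;
}

int main() {
  std::vector<float> scores = {0.9f, 0.1f, 0.7f, 0.3f};
  std::vector<long> output(scores.size(), -1);  // capacity = boxes_num
  int output_num = 0;
  fake_nms(scores, output, &output_num);
  output.resize(output_num);  // host-side analogue of slice(0, 0, output_num)
  for (long idx : output) std::cout << idx << " ";  // prints: 0 2
  std::cout << "\n";
}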
mmcv/ops/csrc/pytorch/mlu/psamask_mlu.cpp
View file @
91da9643
...
@@ -9,136 +9,7 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"
#include "psamask_utils.hpp"

#define COMPUTE_COUNT_ALIGN 64

// Dropped in this commit: the includes of <algorithm>,
// "pytorch_device_registry.hpp" and "pytorch_mlu_helper.hpp", the extern
// declarations of KernelPsamaskForward / KernelPsamaskBackward, the
// anonymous-namespace policyFunc() that split either the batch (PARTITION_N)
// or the feature height (PARTITION_H) across clusters and cores, and the
// findLimit() helper that derived NRAM segment limits (limit_n_seg,
// limit_h_seg, limit_w_seg) from cnrtAttrNramSizePerMcore with
// CEIL_ALIGN-based sizing, falling back from N to H to W splits and logging
// "The size of input channel is too large." when even a single W segment
// did not fit.

void PSAMaskForwardMLUKernelLauncher(const int psa_type, const Tensor x,
                                     Tensor y, const int num_,
...
@@ -146,39 +17,7 @@ void PSAMaskForwardMLUKernelLauncher(const int psa_type, const Tensor x,
                                     const int h_mask, const int w_mask,
                                     const int half_h_mask,
                                     const int half_w_mask) {
  // params check
  TORCH_CHECK(x.scalar_type() == at::kFloat, "x type should be Float, got ",
              x.scalar_type());
  TORCH_CHECK(y.scalar_type() == x.scalar_type(),
              "y should have the same type as x");
  TORCH_CHECK(x.dim() == 4, "x should be a 4d tensor, got ", x.dim(), "D");
  TORCH_CHECK(y.dim() == 4, "y should be a 4d tensor, got ", y.dim(), "D");

  int x_c = x.size(1);
  int y_c = y.size(1);
  TORCH_CHECK(h_mask * w_mask == x_c,
              "channel of x should be the same as h_mask * w_mask");
  TORCH_CHECK(h_feature * w_feature == y_c,
              "channel of y should be the same as h_feature * w_feature");
  TORCH_CHECK(psa_type == 0 || psa_type == 1,
              "psa_type only supports 'COLLECT' and 'DISTRIBUTE' currently");

  if (x.numel() == 0) {
    CNLOG(INFO) << "skip zero-element tensor";
    return;
  }
  // The removed lines of this hunk then computed the launch partition with
  // policyFunc() and the NRAM segment limits with findLimit() before
  // launching the hand-written kernel.
  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(x.dim());
...
@@ -186,22 +25,18 @@ void PSAMaskForwardMLUKernelLauncher(const int psa_type, const Tensor x,
  at::Tensor y_tmp =
      at::empty({num_, y_c, h_feature, w_feature}, x.options(), memory_format);

  MluOpTensorDescriptor x_desc, y_desc;
  x_desc.set_with_layout(x_tensor, MLUOP_LAYOUT_NHWC);
  y_desc.set_with_layout(y_tmp, MLUOP_LAYOUT_NHWC);

  // get ptr of tensors
  auto handle = mluOpGetCurrentHandle();
  auto x_impl = torch_mlu::getMluTensorImpl(x_tensor);
  auto x_ptr = x_impl->cnnlMalloc();
  auto y_impl = torch_mlu::getMluTensorImpl(y_tmp);
  auto y_ptr = y_impl->cnnlMalloc();

  TORCH_MLUOP_CHECK(mluOpPsamaskForward(handle, psa_type, x_desc.desc(), x_ptr,
                                        h_mask, w_mask, y_desc.desc(), y_ptr));
  // The replaced call was KernelPsamaskForward(k_dim, k_type, queue, x_ptr,
  // y_ptr, (PsamaskType)psa_type, core/cluster partition info, the feature
  // and mask geometry, and the n/h/w segment limits).
  y.copy_(y_tmp);
}
...
@@ -212,39 +47,7 @@ void PSAMaskBackwardMLUKernelLauncher(const int psa_type, const Tensor dy,
                                      const int h_mask, const int w_mask,
                                      const int half_h_mask,
                                      const int half_w_mask) {
  // params check
  TORCH_CHECK(dy.scalar_type() == at::kFloat, "dy type should be Float, got ",
              dy.scalar_type());
  TORCH_CHECK(dx.scalar_type() == dy.scalar_type(),
              "dx should have the same type as dy");
  TORCH_CHECK(dy.dim() == 4, "dy should be a 4d tensor, got ", dy.dim(), "D");
  TORCH_CHECK(dx.dim() == 4, "dx should be a 4d tensor, got ", dx.dim(), "D");

  int dy_c = dy.size(1);
  int dx_c = dx.size(1);
  TORCH_CHECK(h_feature * w_feature == dy_c,
              "channel of dy should be the same as h_feature * w_feature");
  TORCH_CHECK(h_mask * w_mask == dx_c,
              "channel of dx should be the same as h_mask * w_mask");
  TORCH_CHECK(psa_type == 0 || psa_type == 1,
              "psa_type only supports 'COLLECT' and 'DISTRIBUTE' currently");

  if (dx.numel() == 0) {
    CNLOG(INFO) << "skip zero-element tensor";
    return;
  }
  // As in the forward launcher, the removed lines here derived the launch
  // partition and NRAM segment limits with policyFunc() / findLimit().
  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(dy.dim());
...
@@ -252,8 +55,11 @@ void PSAMaskBackwardMLUKernelLauncher(const int psa_type, const Tensor dy,
  at::Tensor dx_tmp = at::empty({num_, dx_c, h_feature, w_feature},
                                dy.options(), memory_format);

  MluOpTensorDescriptor dy_desc, dx_tmp_desc;
  dy_desc.set_with_layout(dy_tensor, MLUOP_LAYOUT_NHWC);
  dx_tmp_desc.set_with_layout(dx_tmp, MLUOP_LAYOUT_NHWC);
  auto handle = mluOpGetCurrentHandle();

  // get ptr of tensors
  auto dx_impl = torch_mlu::getMluTensorImpl(dx_tmp);
...
@@ -261,13 +67,9 @@ void PSAMaskBackwardMLUKernelLauncher(const int psa_type, const Tensor dy,
  auto dy_impl = torch_mlu::getMluTensorImpl(dy_tensor);
  auto dy_ptr = dy_impl->cnnlMalloc();

  TORCH_MLUOP_CHECK(mluOpPsamaskBackward(handle, psa_type, dy_desc.desc(),
                                         dy_ptr, h_mask, w_mask,
                                         dx_tmp_desc.desc(), dx_ptr));
  // The replaced call was KernelPsamaskBackward with the same partition and
  // segment-limit arguments as the forward kernel.
  dx.copy_(dx_tmp);
}
...
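The removed findLimit() sizing mentioned above hinges on CEIL_ALIGN, which simply rounds a length up to the next multiple of an alignment base before dividing the available NRAM by it. A tiny self-contained sketch of that arithmetic follows; the 64-byte base is an example value, not the value of NFU_ALIGN_SIZE or COMPUTE_COUNT_ALIGN on any particular device.

// Sketch only: round-up alignment used by CEIL_ALIGN-style sizing.
#include <iostream>

// Round `x` up to the nearest multiple of `base` (base > 0).
static int ceil_align(int x, int base) { return (x + base - 1) / base * base; }

int main() {
  const int align_base = 64 / static_cast<int>(sizeof(float));  // 16 elements
  std::cout << ceil_align(100, align_base) << "\n";  // 112
  std::cout << ceil_align(112, align_base) << "\n";  // 112 (already aligned)
}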
mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp
View file @
91da9643
...
@@ -9,26 +9,7 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "mlu_common_helper.h"

// Dropped in this commit: the includes of "pytorch_device_registry.hpp" and
// "pytorch_mlu_helper.hpp" and the extern declarations of
// KernelRoiAlign(k_dim, k_type, queue, d_type, input, rois, channels,
// aligned, pooled_height, pooled_width, input_height, input_width,
// sampling_ratio, spatial_scale, num_rois, output) and
// KernelRoiAlignBackward(k_dim, k_type, queue, dtype, grads, boxes,
// grads_image, boxes_num, hi, wi, c, no, ho, wo, spatial_scale,
// sampling_ratio, aligned).

void ROIAlignForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output,
                                      Tensor argmax_y, Tensor argmax_x,
...
@@ -36,17 +17,7 @@ void ROIAlignForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output,
                                      float spatial_scale, int sampling_ratio,
                                      int pool_mode, bool aligned) {
  // params check
  // (The removed checks of this hunk verified that input is a 4-D Float/Half
  // tensor and rois a 2-D tensor of the same dtype.)
  TORCH_CHECK(pool_mode == 1, "pool_mode only supports 'avg' currently");
  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
  auto input_tensor =
...
@@ -57,52 +28,56 @@ void ROIAlignForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output,
  int height = input.size(2);
  int width = input.size(3);
  if (output.numel() == 0) {
    output = at::zeros({num_rois, channels, aligned_height, aligned_width},
                       input.options());
    return;
  }
  auto output_contiguous =
      at::empty({num_rois, channels, aligned_height, aligned_width},
                input.options(), memory_format);

  // get tensor impl
  auto self_impl = torch_mlu::getMluTensorImpl(input_tensor);
  auto rois_impl = torch_mlu::getMluTensorImpl(rois);
  auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);

  MluOpTensorDescriptor input_desc, rois_desc, argmax_y_desc, argmax_x_desc,
      output_desc;
  input_desc.set_with_layout(input_tensor, MLUOP_LAYOUT_NHWC);
  rois_desc.set_with_layout(rois, MLUOP_LAYOUT_ARRAY);
  output_desc.set_with_layout(output_contiguous, MLUOP_LAYOUT_NHWC);

  // get the mlu ptr
  auto self_ptr = self_impl->cnnlMalloc();
  auto rois_ptr = rois_impl->cnnlMalloc();
  auto output_ptr = output_impl->cnnlMalloc();

  mluOpRoiAlignForwardDescriptor_t roialign_desc;
  TORCH_MLUOP_CHECK(mluOpCreateRoiAlignForwardDescriptor(&roialign_desc));
  TORCH_MLUOP_CHECK(mluOpSetRoiAlignForwardDescriptor_v2(
      roialign_desc, aligned_height, aligned_width, sampling_ratio,
      spatial_scale, pool_mode, aligned));

  auto handle = mluOpGetCurrentHandle();
  if (pool_mode == 0) {
    auto argmax_y_contiguous =
        torch_mlu::cnnl::ops::cnnl_contiguous(argmax_y, memory_format);
    auto argmax_x_contiguous =
        torch_mlu::cnnl::ops::cnnl_contiguous(argmax_x, memory_format);
    auto argmax_x_impl = torch_mlu::getMluTensorImpl(argmax_x_contiguous);
    auto argmax_y_impl = torch_mlu::getMluTensorImpl(argmax_y_contiguous);
    auto argmax_x_ptr = argmax_x_impl->cnnlMalloc();
    auto argmax_y_ptr = argmax_y_impl->cnnlMalloc();
    argmax_y_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
    argmax_x_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
    TORCH_MLUOP_CHECK(mluOpRoiAlignForward_v2(
        handle, roialign_desc, input_desc.desc(), self_ptr, rois_desc.desc(),
        rois_ptr, output_desc.desc(), output_ptr, argmax_x_desc.desc(),
        argmax_x_ptr, argmax_y_desc.desc(), argmax_y_ptr));
    argmax_x.copy_(argmax_x_contiguous);
    argmax_y.copy_(argmax_y_contiguous);
  } else {
    TORCH_MLUOP_CHECK(mluOpRoiAlignForward_v2(
        handle, roialign_desc, input_desc.desc(), self_ptr, rois_desc.desc(),
        rois_ptr, output_desc.desc(), output_ptr, NULL, NULL, NULL, NULL));
  }
  TORCH_MLUOP_CHECK(mluOpDestroyRoiAlignForwardDescriptor(roialign_desc));
  output.copy_(output_contiguous);

  // The replaced path allocated output_tmp, derived a CNRT_FUNC_TYPE_UNION1
  // launch dimension from the device attributes, and called KernelRoiAlign
  // on the current queue before copying output_tmp back into output; a
  // nearestPower2() bit-twiddling helper used by the old backward launch
  // sizing is removed as well.
}

void ROIAlignBackwardMLUKernelLauncher(Tensor grad, Tensor rois,
...
@@ -112,17 +87,7 @@ void ROIAlignBackwardMLUKernelLauncher(Tensor grad, Tensor rois,
                                       int sampling_ratio, int pool_mode,
                                       bool aligned) {
  // params check
  // (The removed checks of this hunk verified that grad is a 4-D Float/Half
  // tensor and rois a 2-D tensor of the same dtype.)
  TORCH_CHECK(pool_mode == 1, "pool_mode only supports 'avg' currently");
  int batch_size = grad_input.size(0);
  int channels = grad_input.size(1);
  int height = grad_input.size(2);
...
@@ -148,26 +113,40 @@ void ROIAlignBackwardMLUKernelLauncher(Tensor grad, Tensor rois,
  auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_);
  auto rois_impl = torch_mlu::getMluTensorImpl(rois);

  // get the mlu ptr
  auto grad_ptr = grad_impl->cnnlMalloc();
  auto rois_ptr = rois_impl->cnnlMalloc();
  auto grad_input_ptr = grad_input_impl->cnnlMalloc();

  MluOpTensorDescriptor grads_desc, rois_desc, argmax_y_desc, argmax_x_desc,
      grad_input_desc;
  grads_desc.set_with_layout(grad_, MLUOP_LAYOUT_NHWC);
  rois_desc.set_with_layout(rois, MLUOP_LAYOUT_ARRAY);
  grad_input_desc.set_with_layout(grad_input_, MLUOP_LAYOUT_NHWC);
  auto handle = mluOpGetCurrentHandle();

  if (pool_mode == 0) {
    auto argmax_y_contiguous =
        torch_mlu::cnnl::ops::cnnl_contiguous(argmax_y, memory_format);
    auto argmax_x_contiguous =
        torch_mlu::cnnl::ops::cnnl_contiguous(argmax_x, memory_format);
    auto argmax_x_impl = torch_mlu::getMluTensorImpl(argmax_x_contiguous);
    auto argmax_y_impl = torch_mlu::getMluTensorImpl(argmax_y_contiguous);
    auto argmax_x_ptr = argmax_x_impl->cnnlMalloc();
    auto argmax_y_ptr = argmax_y_impl->cnnlMalloc();
    argmax_y_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
    argmax_x_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
    TORCH_MLUOP_CHECK(mluOpRoiAlignBackward_v2(
        handle, grads_desc.desc(), grad_ptr, rois_desc.desc(), rois_ptr,
        argmax_y_desc.desc(), argmax_x_ptr, argmax_y_desc.desc(),
        argmax_y_ptr, spatial_scale, sampling_ratio, aligned, pool_mode,
        grad_input_desc.desc(), grad_input_ptr));
  } else {
    TORCH_MLUOP_CHECK(mluOpRoiAlignBackward_v2(
        handle, grads_desc.desc(), grad_ptr, rois_desc.desc(), rois_ptr, NULL,
        NULL, NULL, NULL, spatial_scale, sampling_ratio, aligned, pool_mode,
        grad_input_desc.desc(), grad_input_ptr));
  }
  // The replaced path sized the launch grid from nearestPower2(boxes_num)
  // and the cluster count, then called KernelRoiAlignBackward with the raw
  // pointers and geometry arguments.
  grad_input.copy_(grad_input_);
}
...
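Both ROIAlign launchers above hand mluOp channels-last (NHWC) buffers and set MLUOP_LAYOUT_NHWC on the descriptors, which is why they go through cnnl_contiguous with a channels-last memory format first. As background, here is a self-contained sketch of how the flat offset of the same logical element (n, c, h, w) differs between NCHW and NHWC storage; the shapes are arbitrary example values, not anything taken from the kernels.

// Sketch only: flat-offset computation for NCHW vs. NHWC layouts.
#include <iostream>

static long idx_nchw(long n, long c, long h, long w, long C, long H, long W) {
  return ((n * C + c) * H + h) * W + w;
}
static long idx_nhwc(long n, long c, long h, long w, long C, long H, long W) {
  return ((n * H + h) * W + w) * C + c;
}

int main() {
  const long C = 3, H = 4, W = 5;
  // Same logical element, two different physical offsets.
  std::cout << idx_nchw(0, 1, 2, 3, C, H, W) << "\n";  // 33
  std::cout << idx_nhwc(0, 1, 2, 3, C, H, W) << "\n";  // 40
}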
mmcv/ops/csrc/pytorch/mlu/roi_align_rotated_mlu.cpp
100755 → 100644
View file @
91da9643
...
@@ -9,37 +9,7 @@
...
@@ -9,37 +9,7 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "mlu_common_helper.h"
#include "pytorch_mlu_helper.hpp"
#include "roi_align_rotated_utils.hpp"
namespace
{
void
policyFunc
(
int
bin_num
,
cnrtDim3_t
*
k_dim
,
cnrtFunctionType_t
*
k_type
)
{
unsigned
int
core_num
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrMcorePerCluster
);
unsigned
int
cluster_num
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrClusterCount
);
*
k_type
=
CNRT_FUNC_TYPE_UNION1
;
k_dim
->
x
=
core_num
;
unsigned
int
use_cluster
=
(
bin_num
+
core_num
-
1
)
/
core_num
;
k_dim
->
y
=
use_cluster
>
cluster_num
?
cluster_num
:
use_cluster
;
k_dim
->
z
=
1
;
}
}
// namespace
void
KernelRoiAlignRotatedForward
(
cnrtDim3_t
k_dim
,
cnrtFunctionType_t
k_type
,
cnrtQueue_t
queue
,
const
cnrtDataType_t
d_type
,
const
void
*
features
,
const
void
*
rois
,
void
*
output
,
const
int
batch
,
const
int
height
,
const
int
width
,
const
int
channel
,
const
int
rois_num
,
const
RoiAlignRotatedParams
roiAlignRotatedParams
);
void
KernelRoiAlignRotatedBackward
(
cnrtDim3_t
k_dim
,
cnrtFunctionType_t
k_type
,
cnrtQueue_t
queue
,
const
cnrtDataType_t
d_type
,
const
void
*
top_grad
,
const
void
*
rois
,
void
*
bottom_grad
,
const
int
batch
,
const
int
height
,
const
int
width
,
const
int
channel
,
const
int
rois_num
,
const
RoiAlignRotatedParams
roiAlignRotatedParams
);
void
ROIAlignRotatedForwardMLUKernelLauncher
(
Tensor
input
,
Tensor
rois
,
void
ROIAlignRotatedForwardMLUKernelLauncher
(
Tensor
input
,
Tensor
rois
,
Tensor
output
,
int
pooled_height
,
Tensor
output
,
int
pooled_height
,
...
@@ -47,153 +17,70 @@ void ROIAlignRotatedForwardMLUKernelLauncher(Tensor input, Tensor rois,
...
@@ -47,153 +17,70 @@ void ROIAlignRotatedForwardMLUKernelLauncher(Tensor input, Tensor rois,
float
spatial_scale
,
float
spatial_scale
,
int
sampling_ratio
,
bool
aligned
,
int
sampling_ratio
,
bool
aligned
,
bool
clockwise
)
{
bool
clockwise
)
{
TORCH_CHECK
(((
input
.
scalar_type
()
==
output
.
scalar_type
())
&&
(
output
.
scalar_type
()
==
rois
.
scalar_type
())),
"data types of input, rois and output should be the same, "
,
"but now input type is "
,
input
.
scalar_type
(),
", rois type is "
,
rois
.
scalar_type
(),
", output type is "
,
output
.
scalar_type
(),
"."
);
TORCH_CHECK
(
(
input
.
scalar_type
()
==
at
::
kFloat
||
input
.
scalar_type
()
==
at
::
kHalf
),
"input type should be Float or Half, got "
,
input
.
scalar_type
(),
"."
);
TORCH_CHECK
(
input
.
dim
()
==
4
,
"input should be a 4d tensor, got "
,
input
.
dim
(),
"D."
);
TORCH_CHECK
(
rois
.
dim
()
==
2
,
"rois should be a 2d tensor, got "
,
rois
.
dim
(),
"D."
);
TORCH_CHECK
(
output
.
dim
()
==
4
,
"output should be a 4d tensor, got "
,
output
.
dim
(),
"D."
);
TORCH_CHECK
((
rois
.
size
(
0
)
==
output
.
size
(
0
)),
"the 1st dimensions of rois and output should be the same, "
,
"but now the 1st dimension of rois is "
,
rois
.
size
(
0
),
", and output is "
,
output
.
size
(
0
),
"."
);
TORCH_CHECK
((
input
.
size
(
1
)
==
output
.
size
(
1
)),
"the 2nd dimensions of input and output should be the same, "
,
"but now the 2nd dimension of input is "
,
input
.
size
(
1
),
", and output is "
,
output
.
size
(
1
),
"."
);
int
channel
=
input
.
size
(
1
);
int
width
=
input
.
size
(
3
);
int
height
=
input
.
size
(
2
);
int
batch
=
input
.
size
(
0
);
int
rois_nums
=
rois
.
size
(
0
);
cnrtDataType_t
d_type
=
torch_mlu
::
toCnrtDtype
(
input
.
dtype
());
// return if zero-elements
if
(
input
.
numel
()
==
0
)
{
CNLOG
(
INFO
)
<<
"Skip the zero-elements case."
;
return
;
}
RoiAlignRotatedParams
roiAlignRotatedParams
{
pooled_height
,
pooled_width
,
sampling_ratio
,
spatial_scale
,
aligned
,
clockwise
};
cnrtDim3_t
k_dim
;
cnrtFunctionType_t
k_type
;
policyFunc
(
rois_nums
*
pooled_height
*
pooled_width
,
&
k_dim
,
&
k_type
);
auto
memory_format
=
auto
memory_format
=
torch_mlu
::
cnnl
::
ops
::
get_channels_last_memory_format
(
input
.
dim
());
torch_mlu
::
cnnl
::
ops
::
get_channels_last_memory_format
(
input
.
dim
());
auto
input_
tensor
=
auto
input_
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
input
,
memory_format
);
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
input
,
memory_format
);
auto
rois_contiguous
=
at
::
Tensor
output_tmp
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
rois
,
rois
.
suggest_memory_format
());
at
::
empty
({
rois_nums
,
channel
,
pooled_height
,
pooled_width
},
auto
output_contiguous
=
input
.
options
()
,
memory_format
);
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
output
,
memory_format
);
// get compute queue
MluOpTensorDescriptor
input_desc
,
rois_desc
,
output_desc
;
auto
queue
=
torch_mlu
::
getCurQueue
();
input_desc
.
set_with_layout
(
input_
,
MLUOP_LAYOUT_NHWC
);
rois_desc
.
set
(
rois_contiguous
);
output_desc
.
set_with_layout
(
output_contiguous
,
MLUOP_LAYOUT_NHWC
);
// get ptr of tensors
// get ptr of tensors
auto
input_impl
=
torch_mlu
::
getMluTensorImpl
(
input_
tensor
);
auto
input_impl
=
torch_mlu
::
getMluTensorImpl
(
input_
);
auto
input_ptr
=
input_impl
->
cnnlMalloc
();
auto
input_ptr
=
input_impl
->
cnnlMalloc
();
auto
rois_impl
=
torch_mlu
::
getMluTensorImpl
(
rois
);
auto
rois_impl
=
torch_mlu
::
getMluTensorImpl
(
rois
_contiguous
);
auto
rois_ptr
=
rois_impl
->
cnnlMalloc
();
auto
rois_ptr
=
rois_impl
->
cnnlMalloc
();
auto
output_impl
=
torch_mlu
::
getMluTensorImpl
(
output_
tmp
);
auto
output_impl
=
torch_mlu
::
getMluTensorImpl
(
output_
contiguous
);
auto
output_ptr
=
output_impl
->
cnnlMalloc
();
auto
output_ptr
=
output_impl
->
cnnlMalloc
();
KernelRoiAlignRotatedForward
(
k_dim
,
k_type
,
queue
,
d_type
,
input_ptr
,
// get compute handle
rois_ptr
,
output_ptr
,
batch
,
height
,
width
,
auto
handle
=
mluOpGetCurrentHandle
();
channel
,
rois_nums
,
roiAlignRotatedParams
);
TORCH_MLUOP_CHECK
(
mluOpRoiAlignRotatedForward
(
output
.
copy_
(
output_tmp
);
handle
,
input_desc
.
desc
(),
input_ptr
,
rois_desc
.
desc
(),
rois_ptr
,
pooled_height
,
pooled_width
,
sampling_ratio
,
spatial_scale
,
aligned
,
clockwise
,
output_desc
.
desc
(),
output_ptr
));
output
.
copy_
(
output_contiguous
);
}
}
void
ROIAlignRotatedBackwardMLUKernelLauncher
(
void
ROIAlignRotatedBackwardMLUKernelLauncher
(
Tensor
top_grad
,
Tensor
rois
,
Tensor
bottom_grad
,
int
pooled_height
,
Tensor
top_grad
,
Tensor
rois
,
Tensor
bottom_grad
,
int
pooled_height
,
int
pooled_width
,
float
spatial_scale
,
int
sampling_ratio
,
bool
aligned
,
int
pooled_width
,
float
spatial_scale
,
int
sampling_ratio
,
    bool aligned, bool clockwise) {

Removed in this commit (old hand-written CNRT kernel path):

  TORCH_CHECK(((top_grad.scalar_type() == bottom_grad.scalar_type()) &&
               (bottom_grad.scalar_type() == rois.scalar_type())),
              "data types of top_grad, rois and bottom_grad should be ",
              "the same, but now top_grad type is ", top_grad.scalar_type(),
              ", rois type is ", rois.scalar_type(), ", bottom_grad type is ",
              bottom_grad.scalar_type(), ".");
  TORCH_CHECK((bottom_grad.scalar_type() == at::kFloat ||
               bottom_grad.scalar_type() == at::kHalf),
              "Data type of bottom_grad should be Float or Half, got ",
              bottom_grad.scalar_type(), ".");
  TORCH_CHECK(bottom_grad.dim() == 4, "bottom_grad should be a 4d tensor, got ",
              top_grad.dim(), "D.");
  TORCH_CHECK(rois.dim() == 2, "rois should be a 2d tensor, got ", rois.dim(),
              "D.");
  TORCH_CHECK(top_grad.dim() == 4, "top_grad should be a 4d tensor, got ",
              bottom_grad.dim(), "D.");
  TORCH_CHECK((rois.size(0) == top_grad.size(0)),
              "the 1st dimensions of rois and top_grad should be the same, ",
              "but now the 1st dimension of rois is ", rois.size(0),
              ", and top_grad is ", top_grad.size(0), ".");
  TORCH_CHECK((bottom_grad.size(1) == top_grad.size(1)),
              "the 2nd dimensions of bottom_grad and top_grad should be ",
              "the same, but now the 2nd dimension of bottom_grad is ",
              bottom_grad.size(1), ", and top_grad is ", top_grad.size(1), ".");

  int channel = bottom_grad.size(1);
  int width = bottom_grad.size(3);
  int height = bottom_grad.size(2);
  int batch = bottom_grad.size(0);
  int rois_nums = rois.size(0);
  cnrtDataType_t d_type = torch_mlu::toCnrtDtype(bottom_grad.dtype());

  // return if zero-elements
  if (bottom_grad.numel() == 0) {
    CNLOG(INFO) << "Skip the zero-elements case.";
    return;
  }

  RoiAlignRotatedParams roiAlignRotatedParams{pooled_height,  pooled_width,
                                              sampling_ratio, spatial_scale,
                                              aligned,        clockwise};
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  policyFunc(rois_nums * pooled_height * pooled_width, &k_dim, &k_type);

  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(top_grad.dim());
  auto top_grad_tensor =
      torch_mlu::cnnl::ops::cnnl_contiguous(top_grad, memory_format);
  at::Tensor bottom_grad_tmp = at::empty({batch, channel, height, width},
                                         top_grad.options(), memory_format)
                                   .zero_();

  // get compute queue
  auto queue = torch_mlu::getCurQueue();

  // get ptr of tensors
  auto bottom_grad_impl = torch_mlu::getMluTensorImpl(bottom_grad_tmp);
  auto bottom_grad_ptr = bottom_grad_impl->cnnlMalloc();
  auto rois_impl = torch_mlu::getMluTensorImpl(rois);
  auto rois_ptr = rois_impl->cnnlMalloc();
  auto top_grad_impl = torch_mlu::getMluTensorImpl(top_grad_tensor);
  auto top_grad_ptr = top_grad_impl->cnnlMalloc();

  KernelRoiAlignRotatedBackward(k_dim, k_type, queue, d_type, top_grad_ptr,
                                rois_ptr, bottom_grad_ptr, batch, height,
                                width, channel, rois_nums,
                                roiAlignRotatedParams);
  bottom_grad.copy_(bottom_grad_tmp);
}

Added in this commit (mluOp-based path, v2.1.0):

  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(top_grad.dim());
  auto top_grad_ =
      torch_mlu::cnnl::ops::cnnl_contiguous(top_grad, memory_format);
  auto rois_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format());
  auto bottom_grad_ =
      torch_mlu::cnnl::ops::cnnl_contiguous(bottom_grad, memory_format);

  // get ptr of tensors
  auto top_grad_impl = torch_mlu::getMluTensorImpl(top_grad_);
  auto top_grad_ptr = top_grad_impl->cnnlMalloc();
  auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous);
  auto rois_ptr = rois_impl->cnnlMalloc();
  auto bottom_grad_impl = torch_mlu::getMluTensorImpl(bottom_grad_);
  auto bottom_grad_ptr = bottom_grad_impl->cnnlMalloc();

  MluOpTensorDescriptor top_grad_desc, rois_desc, bottom_grad_desc;
  top_grad_desc.set_with_layout(top_grad_, MLUOP_LAYOUT_NHWC);
  rois_desc.set(rois_contiguous);
  bottom_grad_desc.set_with_layout(bottom_grad_, MLUOP_LAYOUT_NHWC);

  // get compute handle
  auto handle = mluOpGetCurrentHandle();
  TORCH_MLUOP_CHECK(mluOpRoiAlignRotatedBackward(
      handle, top_grad_desc.desc(), top_grad_ptr, rois_desc.desc(), rois_ptr,
      pooled_height, pooled_width, sampling_ratio, spatial_scale, aligned,
      clockwise, bottom_grad_desc.desc(), bottom_grad_ptr));
  bottom_grad.copy_(bottom_grad_);
}

void roi_align_rotated_forward_mlu(Tensor input, Tensor rois, Tensor output,
...
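The added backward path tags the top_grad and bottom_grad descriptors with MLUOP_LAYOUT_NHWC after making both tensors channels-last contiguous. As a side note, here is a minimal, standalone C++ sketch (not part of mmcv, purely illustrative) of the NCHW versus NHWC offset arithmetic that such a layout tag implies:

#include <cstddef>
#include <cstdio>

// Offset of element (n, c, h, w) in a dense NCHW buffer.
static std::size_t nchw_offset(std::size_t C, std::size_t H, std::size_t W,
                               std::size_t n, std::size_t c, std::size_t h,
                               std::size_t w) {
  return ((n * C + c) * H + h) * W + w;
}

// Offset of the same element once the buffer is stored channels-last (NHWC),
// which is the ordering a descriptor tagged with an NHWC layout expects.
static std::size_t nhwc_offset(std::size_t C, std::size_t H, std::size_t W,
                               std::size_t n, std::size_t c, std::size_t h,
                               std::size_t w) {
  return ((n * H + h) * W + w) * C + c;
}

int main() {
  // Tiny example: a 1x2x2x3 gradient tensor, element (n=0, c=1, h=0, w=2).
  const std::size_t C = 2, H = 2, W = 3;
  std::printf("NCHW offset: %zu\n", nchw_offset(C, H, W, 0, 1, 0, 2));  // 8
  std::printf("NHWC offset: %zu\n", nhwc_offset(C, H, W, 0, 1, 0, 2));  // 5
  return 0;
}

Making the tensor channels-last contiguous first is what guarantees the NHWC offsets above actually address the right elements in device memory.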
mmcv/ops/csrc/pytorch/mlu/roiaware_pool3d_mlu.cpp
View file @
91da9643
...
@@ -9,49 +9,7 @@
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "pytorch_device_registry.hpp"
#include "mlu_common_helper.h"
#include "pytorch_mlu_helper.hpp"

Removed in this commit (hand-written kernel declarations and launch-policy helpers):

void KernelPtsIdxOfVoxels(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
                          cnrtQueue_t queue, const cnrtDataType_t d_type,
                          const int pool_method, const int boxes_num,
                          const int pts_num, const int max_pts_each_voxel,
                          const int out_x, const int out_y, const int out_z,
                          const void *rois, const void *pts,
                          int *pts_idx_of_voxels);

void KernelRoiawarePool3dForward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const int pool_method, const int boxes_num,
    const int pts_num, const int channels, const int max_pts_each_voxel,
    const int out_x, const int out_y, const int out_z, const void *pts_feature,
    const int *pts_idx_of_voxels, void *pooled_features, int *argmax);

// policy function
static void kernelPtsIdxOfVoxelsPolicyFunc(const int boxes_num,
                                           cnrtDim3_t *k_dim,
                                           cnrtFunctionType_t *k_type) {
  unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  *k_type = CNRT_FUNC_TYPE_UNION1;
  k_dim->x = core_num;
  unsigned int use_cluster = (boxes_num + core_num - 1) / core_num;
  k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster;
  k_dim->z = 1;
}

static void kernelRoiawarePool3dForwardPolicyFunc(
    const int boxes_num, const int out_x, const int out_y, const int out_z,
    cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
  unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  *k_type = CNRT_FUNC_TYPE_UNION1;
  k_dim->x = core_num;
  const int voxels_num = boxes_num * out_x * out_y * out_z;
  unsigned int use_cluster = (voxels_num + core_num - 1) / core_num;
  k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster;
  k_dim->z = 1;
}
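The removed policy functions all apply the same launch-dimension rule: one UNION1 task group with core_num cores per cluster on the x axis, and ceil(jobs / core_num) clusters on the y axis, clamped to the device's cluster count. A standalone sketch of just that arithmetic, using hypothetical core/cluster counts (the real values come from torch_mlu::getDeviceAttr):

#include <cstdio>

// Mirror of the launch-dimension arithmetic in the removed policy functions:
// ceil(jobs / core_num) clusters, clamped to the clusters actually present.
static unsigned int clusters_for(unsigned int jobs, unsigned int core_num,
                                 unsigned int cluster_num) {
  unsigned int use_cluster = (jobs + core_num - 1) / core_num;   // ceil divide
  return use_cluster > cluster_num ? cluster_num : use_cluster;  // clamp
}

int main() {
  // Hypothetical device: 4 cores per cluster, 8 clusters.
  const unsigned int core_num = 4, cluster_num = 8;
  // Forward policy counts one job per output voxel:
  // jobs = boxes_num * out_x * out_y * out_z.
  const unsigned int boxes_num = 16, out_x = 12, out_y = 12, out_z = 4;
  const unsigned int jobs = boxes_num * out_x * out_y * out_z;  // 9216 voxels
  std::printf("k_dim = (%u, %u, 1)\n", core_num,
              clusters_for(jobs, core_num, cluster_num));       // (4, 8, 1)
  return 0;
}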
void RoiawarePool3dForwardMLUKernelLauncher(
    const int pool_method, const int boxes_num, const int pts_num,
...
@@ -59,168 +17,65 @@ void RoiawarePool3dForwardMLUKernelLauncher(
    const int out_y, const int out_z, const Tensor rois, const Tensor pts,
    const Tensor pts_feature, Tensor pts_idx_of_voxels, Tensor pooled_features,
    Tensor argmax) {

Removed in this commit (argument checking plus launch of the hand-written kernels):

  // check datatype
  TORCH_CHECK(((pts.scalar_type() == rois.scalar_type()) &&
               (pts_feature.scalar_type() == rois.scalar_type()) &&
               (pooled_features.scalar_type() == rois.scalar_type())),
              "data types of rois, pts, pts_feature and pooled_features "
              "should be the same, ",
              "but now rois type is ", rois.scalar_type(), ", pts type is ",
              pts.scalar_type(), ", pts_feature type is ",
              pts_feature.scalar_type(), ", pooled_features type is ",
              pooled_features.scalar_type(), ".");
  TORCH_CHECK(
      (rois.scalar_type() == at::kFloat || rois.scalar_type() == at::kHalf),
      "rois type should be Float or Half, got ", rois.scalar_type(), ".");
  TORCH_CHECK((pts_idx_of_voxels.scalar_type() == at::kInt),
              "pts_idx_of_voxels type should be Int, got ",
              pts_idx_of_voxels.scalar_type(), ".");
  // check dim
  TORCH_CHECK(rois.dim() == 2, "rois should be a 2D tensor, got ", rois.dim(),
              "D.");
  TORCH_CHECK(pts.dim() == 2, "pts should be a 2D tensor, got ", pts.dim(),
              "D.");
  TORCH_CHECK(pts_feature.dim() == 2, "pts_feature should be a 2D tensor, got ",
              pts_feature.dim(), "D.");
  TORCH_CHECK(pts_idx_of_voxels.dim() == 5,
              "pts_idx_of_voxels should be a 5D tensor, got ",
              pts_idx_of_voxels.dim(), "D.");
  TORCH_CHECK(pooled_features.dim() == 5,
              "pooled_features should be a 5D tensor, got ",
              pooled_features.dim(), "D.");
  // check shape
  TORCH_CHECK(((rois.size(0) == boxes_num) && (rois.size(1) == 7)),
              "the dimensions of rois should be (boxes_num, 7), ", "but got (",
              rois.size(0), ", ", rois.size(1), ").");
  TORCH_CHECK(((pts.size(0) == pts_num) && (pts.size(1) == 3)),
              "the dimensions of pts should be (pts_num, 3), ", "but got (",
              pts.size(0), ",", pts.size(1), ").");
  TORCH_CHECK(
      ((pts_feature.size(0) == pts_num) && (pts_feature.size(1) == channels)),
      "the dimensions of pts_feature should be (pts_num, channels), ",
      "but got (", pts_feature.size(0), ",", pts_feature.size(1), ").");
  TORCH_CHECK(((pts_idx_of_voxels.size(0) == boxes_num) &&
               (pts_idx_of_voxels.size(1) == out_x) &&
               (pts_idx_of_voxels.size(2) == out_y) &&
               (pts_idx_of_voxels.size(3) == out_z) &&
               (pts_idx_of_voxels.size(4) == max_pts_each_voxel)),
              "the dimensions of pts_idx_of_voxels should be (boxes_num, "
              "out_x, out_y, out_z, max_pts_each_voxel), ",
              "but got (", pts_idx_of_voxels.size(0), ",",
              pts_idx_of_voxels.size(1), ",", pts_idx_of_voxels.size(2), ",",
              pts_idx_of_voxels.size(3), ",", pts_idx_of_voxels.size(4), ").");
  TORCH_CHECK(((pooled_features.size(0) == boxes_num) &&
               (pooled_features.size(1) == out_x) &&
               (pooled_features.size(2) == out_y) &&
               (pooled_features.size(3) == out_z) &&
               (pooled_features.size(4) == channels)),
              "the dimensions of pooled_features should be (boxes_num, out_x, "
              "out_y, out_z, channels), ",
              "but got (", pooled_features.size(0), ",",
              pooled_features.size(1), ",", pooled_features.size(2), ",",
              pooled_features.size(3), ",", pooled_features.size(4), ").");
  // check other params : pool_method
  TORCH_CHECK(((pool_method == 0) || (pool_method == 1)),
              "the num of pool_method should be 0(max) or 1(avg), ", "but got ",
              pool_method, ".");
  // check large tensor
  const size_t max_input_size = 2147483648;
  TORCH_CHECK(rois.numel() < max_input_size,
              "rois element num should be less than 2^31, got ", rois.numel(),
              ".");
  TORCH_CHECK(pts.numel() < max_input_size,
              "pts element num should be less than 2^31, got ", pts.numel(),
              ".");
  TORCH_CHECK(pts_feature.numel() < max_input_size,
              "pts_feature element num should be less than 2^31, got ",
              pts_feature.numel(), ".");
  TORCH_CHECK(pts_idx_of_voxels.numel() < max_input_size,
              "pts_idx_of_voxels element num should be less than 2^31, got ",
              pts_idx_of_voxels.numel(), ".");
  TORCH_CHECK(pooled_features.numel() < max_input_size,
              "pooled_features element num should be less than 2^31, got ",
              pooled_features.numel(), ".");
  // check zero element
  TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ",
              rois.numel());
  TORCH_CHECK(pts.numel() != 0, "pts.numel() should not be zero, got ",
              pts.numel());
  TORCH_CHECK(pts_feature.numel() != 0,
              "pts_feature.numel() should not be zero, got ",
              pts_feature.numel());
  TORCH_CHECK(pts_idx_of_voxels.numel() != 0,
              "pts_idx_of_voxels.numel() should not be zero, got ",
              pts_idx_of_voxels.numel());
  TORCH_CHECK(pooled_features.numel() != 0,
              "pooled_features.numel() should not be zero, got ",
              pooled_features.numel());
  if (pool_method == 0) {
    // check datatype
    TORCH_CHECK((argmax.scalar_type() == at::kInt),
                "argmax type should be Int, got ", argmax.scalar_type(), ".");
    // check dim
    TORCH_CHECK(argmax.dim() == 5, "argmax should be a 5D tensor, got ",
                argmax.dim(), "D.");
    // check shape
    TORCH_CHECK(((argmax.size(0) == boxes_num) && (argmax.size(1) == out_x) &&
                 (argmax.size(2) == out_y) && (argmax.size(3) == out_z) &&
                 (argmax.size(4) == channels)),
                "the dimensions of argmax should be (boxes_num, out_x, out_y, "
                "out_z, channels), ",
                "but got (", argmax.size(0), ",", argmax.size(1), ",",
                argmax.size(2), ",", argmax.size(3), ",", argmax.size(4),
                ").");
    // check large tensor
    TORCH_CHECK(argmax.numel() < max_input_size,
                "argmax element num should be less than 2^31, got ",
                argmax.numel(), ".");
    // check zero element
    TORCH_CHECK(argmax.numel() != 0,
                "argmax.numel() should not be zero, got ", argmax.numel());
    // when pool_method is 0, which is max pool, init argmax data value to -1
    argmax.fill_(static_cast<int>(-1));
  }
  // calculate task one dimension
  cnrtDim3_t k1_dim;
  cnrtFunctionType_t k1_type;
  kernelPtsIdxOfVoxelsPolicyFunc(boxes_num, &k1_dim, &k1_type);
  cnrtDim3_t k2_dim;
  cnrtFunctionType_t k2_type;
  kernelRoiawarePool3dForwardPolicyFunc(boxes_num, out_x, out_y, out_z, &k2_dim,
                                        &k2_type);
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  // get ptr of tensors
  auto rois_impl = torch_mlu::getMluTensorImpl(rois);
  auto rois_ptr = rois_impl->cnnlMalloc();
  // transpose points [pts_num, 3] -> [3, pts_num]
  auto pts_ = pts.permute({1, 0}).contiguous();
  auto pts_impl = torch_mlu::getMluTensorImpl(pts_);
  auto pts_ptr = pts_impl->cnnlMalloc();
  // transpose points_features [pts_num, channels] -> [channels, pts_num]
  auto pts_feature_ = pts_feature.permute({1, 0}).contiguous();
  auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_);
  auto pts_feature_ptr = pts_feature_impl->cnnlMalloc();
  auto pts_idx_of_voxels_impl = torch_mlu::getMluTensorImpl(pts_idx_of_voxels);
  auto pts_idx_of_voxels_ptr = pts_idx_of_voxels_impl->cnnlMalloc();
  auto pooled_features_impl = torch_mlu::getMluTensorImpl(pooled_features);
  auto pooled_features_ptr = pooled_features_impl->cnnlMalloc();
  auto argmax_impl = torch_mlu::getMluTensorImpl(argmax);
  auto argmax_ptr = argmax_impl->cnnlMalloc();
  // get compute dtype of input
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(rois.dtype());
  // launch kernel PtsIdxOfVoxels
  CNLOG(INFO) << "Launch Kernel MLUKernel PtsIdxOfVoxels<<<" << k1_dim.x << ", "
              << k1_dim.y << ", " << k1_dim.z << ">>>";
  KernelPtsIdxOfVoxels(k1_dim, k1_type, queue, data_type, pool_method,
                       boxes_num, pts_num, max_pts_each_voxel, out_x, out_y,
                       out_z, rois_ptr, pts_ptr, (int *)pts_idx_of_voxels_ptr);
  // launch kernel RoiawarePool3dForward
  CNLOG(INFO) << "Launch Kernel MLUKernel RoiawarePool3dForward<<<" << k2_dim.x
              << ", " << k2_dim.y << ", " << k2_dim.z << ">>>";
  KernelRoiawarePool3dForward(
      k2_dim, k2_type, queue, data_type, pool_method, boxes_num, pts_num,
      channels, max_pts_each_voxel, out_x, out_y, out_z, pts_feature_ptr,
      (int *)pts_idx_of_voxels_ptr, pooled_features_ptr, (int *)argmax_ptr);
}

Added in this commit (delegation to mluOpRoiawarePool3dForward):

  // get compute handle
  auto handle = mluOpGetCurrentHandle();
  auto rois_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format());
  auto pts_contiguous =
      torch_mlu::cnnl::ops::cnnl_contiguous(pts, pts.suggest_memory_format());
  auto pts_feature_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      pts_feature, pts_feature.suggest_memory_format());
  auto argmax_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      argmax, argmax.suggest_memory_format());
  auto pts_idx_of_voxels_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      pts_idx_of_voxels, pts_idx_of_voxels.suggest_memory_format());
  auto pooled_features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      pooled_features, pooled_features.suggest_memory_format());

  MluOpTensorDescriptor rois_desc, pts_desc, pts_feature_desc, argmax_desc,
      pts_idx_of_voxels_desc, pooled_features_desc;
  rois_desc.set(rois_contiguous);
  pts_desc.set(pts_contiguous);
  pts_feature_desc.set(pts_feature_contiguous);
  argmax_desc.set(argmax_contiguous);
  pts_idx_of_voxels_desc.set(pts_idx_of_voxels_contiguous);
  pooled_features_desc.set(pooled_features_contiguous);

  // allocate extra space for workspace
  size_t workspace_size = 0;
  TORCH_MLUOP_CHECK(mluOpGetRoiawarePool3dForwardWorkspaceSize(
      handle, rois_desc.desc(), pts_desc.desc(), pts_feature_desc.desc(),
      &workspace_size));
  auto workspace = at::empty(workspace_size, rois.options().dtype(at::kByte));
  auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
  auto workspace_ptr = workspace_impl->cnnlMalloc();

  auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous);
  auto rois_ptr = rois_impl->cnnlMalloc();
  auto pts_impl = torch_mlu::getMluTensorImpl(pts_contiguous);
  auto pts_ptr = pts_impl->cnnlMalloc();
  auto pts_feature_impl = torch_mlu::getMluTensorImpl(pts_feature_contiguous);
  auto pts_feature_ptr = pts_feature_impl->cnnlMalloc();
  auto argmax_impl = torch_mlu::getMluTensorImpl(argmax_contiguous);
  auto argmax_ptr = argmax_impl->cnnlMalloc();
  auto pts_idx_of_voxels_impl =
      torch_mlu::getMluTensorImpl(pts_idx_of_voxels_contiguous);
  auto pts_idx_of_voxels_ptr = pts_idx_of_voxels_impl->cnnlMalloc();
  auto pooled_features_impl =
      torch_mlu::getMluTensorImpl(pooled_features_contiguous);
  auto pooled_features_ptr = pooled_features_impl->cnnlMalloc();

  CNLOG(INFO) << "Call mluOpRoiawarePool3dForward().";
  TORCH_MLUOP_CHECK(mluOpRoiawarePool3dForward(
      handle, pool_method, boxes_num, pts_num, channels, rois_desc.desc(),
      rois_ptr, pts_desc.desc(), pts_ptr, pts_feature_desc.desc(),
      pts_feature_ptr, workspace_ptr, workspace_size, max_pts_each_voxel, out_x,
      out_y, out_z, argmax_desc.desc(), argmax_ptr,
      pts_idx_of_voxels_desc.desc(), pts_idx_of_voxels_ptr,
      pooled_features_desc.desc(), pooled_features_ptr));
}
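Both the removed and the added forward paths keep the same contract: pool_method 0 is max pooling (argmax is pre-filled with -1 so voxels that no point falls into stay zero), and pool_method 1 is average pooling. Below is a small CPU-only C++ illustration of the max branch for a single voxel, using a hypothetical pts_in_voxel index list; it mirrors the general RoIAware max-pooling semantics rather than the MLU kernel itself:

#include <cstdio>
#include <vector>

// Channel-wise max over the points assigned to one voxel, recording which
// point supplied each maximum. An argmax of -1 (the value the launcher
// pre-fills) marks "no point hit this voxel", so no gradient flows back.
static void max_pool_voxel(const std::vector<std::vector<float>> &pts_feature,
                           const std::vector<int> &pts_in_voxel, int channels,
                           std::vector<float> &pooled,
                           std::vector<int> &argmax) {
  pooled.assign(channels, 0.f);
  argmax.assign(channels, -1);
  for (int idx : pts_in_voxel) {
    for (int c = 0; c < channels; ++c) {
      if (argmax[c] == -1 || pts_feature[idx][c] > pooled[c]) {
        pooled[c] = pts_feature[idx][c];
        argmax[c] = idx;
      }
    }
  }
}

int main() {
  std::vector<std::vector<float>> feats = {{0.2f, -1.0f}, {0.9f, -3.0f}};
  std::vector<int> in_voxel = {0, 1};  // both points fall into this voxel
  std::vector<float> pooled;
  std::vector<int> argmax;
  max_pool_voxel(feats, in_voxel, /*channels=*/2, pooled, argmax);
  std::printf("pooled = (%.1f, %.1f), argmax = (%d, %d)\n", pooled[0],
              pooled[1], argmax[0], argmax[1]);  // (0.9, -1.0), (1, 0)
  return 0;
}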
void roiaware_pool3d_forward_mlu(int boxes_num, int pts_num, int channels,
...
@@ -245,136 +100,46 @@ void roiaware_pool3d_forward_impl(int boxes_num, int pts_num, int channels,
REGISTER_DEVICE_IMPL(roiaware_pool3d_forward_impl, MLU,
                     roiaware_pool3d_forward_mlu);
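REGISTER_DEVICE_IMPL binds the generic roiaware_pool3d_forward_impl dispatch point to this MLU launcher. The sketch below is a hypothetical, heavily simplified registry illustrating the (op, device) -> function idea; mmcv's actual machinery lives in pytorch_device_registry.hpp and is macro-based, so this is only a conceptual stand-in:

#include <functional>
#include <map>
#include <stdexcept>
#include <string>

using Fn = std::function<void()>;

// Global table mapping an op name and a device tag to its launcher.
static std::map<std::string, std::map<std::string, Fn>> &registry() {
  static std::map<std::string, std::map<std::string, Fn>> r;
  return r;
}

static void register_impl(const std::string &op, const std::string &device,
                          Fn fn) {
  registry()[op][device] = std::move(fn);
}

static void dispatch(const std::string &op, const std::string &device) {
  auto it = registry().find(op);
  if (it == registry().end() || !it->second.count(device))
    throw std::runtime_error(op + " has no implementation on " + device);
  it->second[device]();  // call the device-specific kernel launcher
}

int main() {
  register_impl("roiaware_pool3d_forward_impl", "MLU",
                [] { /* would call roiaware_pool3d_forward_mlu(...) */ });
  dispatch("roiaware_pool3d_forward_impl", "MLU");
  return 0;
}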
Removed in this commit (hand-written backward kernel declaration and launch policy):

void KernelRoiawarePool3dBackward(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    const cnrtDataType_t d_type, const int pool_method, const int boxes_num,
    const int out_x, const int out_y, const int out_z, const int channels,
    const int max_pts_each_voxel, const int *pts_idx_of_voxels,
    const int *argmax, const void *grad_out, void *grad_in);

static void kernelRoiawarePool3dBackwardPolicyFunc(
    const int boxes_num, const int out_x, const int out_y, const int out_z,
    cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) {
  unsigned int core_num = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
  unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
  *k_type = CNRT_FUNC_TYPE_UNION1;
  k_dim->x = core_num;
  const int voxels_num = boxes_num * out_x * out_y * out_z;
  unsigned int use_cluster = (voxels_num + core_num - 1) / core_num;
  k_dim->y = use_cluster > cluster_num ? cluster_num : use_cluster;
  k_dim->z = 1;
}

void RoiawarePool3dBackwardMLUKernelLauncher(
    int pool_method, int boxes_num, int out_x, int out_y, int out_z,
    int channels, int max_pts_each_voxel, const Tensor pts_idx_of_voxels,
    const Tensor argmax, const Tensor grad_out, Tensor grad_in) {

Removed in this commit (argument checking plus launch of KernelRoiawarePool3dBackward):

  // check datatype
  TORCH_CHECK((pts_idx_of_voxels.scalar_type() == at::kInt),
              "pts_idx_of_voxels type should be Int, got ",
              pts_idx_of_voxels.scalar_type(), ".");
  TORCH_CHECK((argmax.scalar_type() == at::kInt),
              "argmax type should be Int, got ", argmax.scalar_type(), ".");
  TORCH_CHECK((grad_out.scalar_type() == at::kFloat ||
               grad_out.scalar_type() == at::kHalf),
              "grad_out type should be Float or Half, got ",
              grad_out.scalar_type(), ".");
  TORCH_CHECK((grad_out.scalar_type() == grad_in.scalar_type()),
              "data types of grad_out, grad_in, should be the same, ",
              "but now grad_out type is ", grad_out.scalar_type(),
              ", grad_in type is ", grad_in.scalar_type(), ".");
  // check dim
  TORCH_CHECK(pts_idx_of_voxels.dim() == 5,
              "pts_idx_of_voxels should be a 5D tensor, got ",
              pts_idx_of_voxels.dim(), "D.");
  TORCH_CHECK(argmax.dim() == 5, "argmax should be a 5D tensor, got ",
              argmax.dim(), "D.");
  TORCH_CHECK(grad_out.dim() == 5, "grad_out should be a 5D tensor, got ",
              grad_out.dim(), "D.");
  TORCH_CHECK(grad_in.dim() == 2, "grad_in should be a 2D tensor, got ",
              grad_in.dim(), "D.");
  // check shape
  TORCH_CHECK(((pts_idx_of_voxels.size(0) == boxes_num) &&
               (pts_idx_of_voxels.size(1) == out_x) &&
               (pts_idx_of_voxels.size(2) == out_y) &&
               (pts_idx_of_voxels.size(3) == out_z) &&
               (pts_idx_of_voxels.size(4) == max_pts_each_voxel)),
              "the dimensions of pts_idx_of_voxels should be (boxes_num, "
              "out_x, out_y, out_z, max_pts_each_voxel), ",
              "but got (", pts_idx_of_voxels.size(0), ",",
              pts_idx_of_voxels.size(1), ",", pts_idx_of_voxels.size(2), ",",
              pts_idx_of_voxels.size(3), ",", pts_idx_of_voxels.size(4), ").");
  TORCH_CHECK(((argmax.size(0) == boxes_num) && (argmax.size(1) == out_x) &&
               (argmax.size(2) == out_y) && (argmax.size(3) == out_z) &&
               (argmax.size(4) == channels)),
              "the dimensions of argmax should be (boxes_num, out_x, out_y, "
              "out_z, channels), ",
              "but got (", argmax.size(0), ",", argmax.size(1), ",",
              argmax.size(2), ",", argmax.size(3), ",", argmax.size(4), ").");
  TORCH_CHECK(((grad_out.size(0) == boxes_num) && (grad_out.size(1) == out_x) &&
               (grad_out.size(2) == out_y) && (grad_out.size(3) == out_z) &&
               (grad_out.size(4) == channels)),
              "the dimensions of grad_out should be (boxes_num, out_x, "
              "out_y, out_z, channels), ",
              "but got (", grad_out.size(0), ",", grad_out.size(1), ",",
              grad_out.size(2), ",", grad_out.size(3), ",", grad_out.size(4),
              ").");
  TORCH_CHECK((grad_in.size(1) == channels),
              "the 2nd dimension of grad_in should be channels, ", "but got ",
              grad_in.size(1), ".");
  // check other params : pool_method
  TORCH_CHECK(((pool_method == 0) || (pool_method == 1)),
              "the num of pool_method should be 0(max) or 1(avg), ", "but got ",
              pool_method, ".");
  // check large tensor
  const size_t max_input_size = 2147483648;
  TORCH_CHECK(pts_idx_of_voxels.numel() < max_input_size,
              "pts_idx_of_voxels element num should be less than 2^31, got ",
              pts_idx_of_voxels.numel(), ".");
  TORCH_CHECK(argmax.numel() < max_input_size,
              "argmax element num should be less than 2^31, got ",
              argmax.numel(), ".");
  TORCH_CHECK(grad_out.numel() < max_input_size,
              "grad_out element num should be less than 2^31, got ",
              grad_out.numel(), ".");
  TORCH_CHECK(grad_in.numel() < max_input_size,
              "grad_in element num should be less than 2^31, got ",
              grad_in.numel(), ".");
  // check zero element
  TORCH_CHECK(pts_idx_of_voxels.numel() != 0,
              "pts_idx_of_voxels.numel() should not be zero, got ",
              pts_idx_of_voxels.numel());
  TORCH_CHECK(argmax.numel() != 0, "argmax.numel() should not be zero, got ",
              argmax.numel());
  TORCH_CHECK(grad_out.numel() != 0,
              "grad_out.numel() should not be zero, got ", grad_out.numel());
  TORCH_CHECK(grad_in.numel() != 0,
              "grad_in.numel() should not be zero, got ", grad_in.numel());
  // calculate task one dimension
  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  kernelRoiawarePool3dBackwardPolicyFunc(boxes_num, out_x, out_y, out_z, &k_dim,
                                         &k_type);
  // get compute queue
  auto queue = torch_mlu::getCurQueue();
  // get ptr of tensors
  auto pts_idx_of_voxels_impl = torch_mlu::getMluTensorImpl(pts_idx_of_voxels);
  auto pts_idx_of_voxels_ptr = pts_idx_of_voxels_impl->cnnlMalloc();
  auto argmax_impl = torch_mlu::getMluTensorImpl(argmax);
  auto argmax_ptr = argmax_impl->cnnlMalloc();
  auto grad_out_impl = torch_mlu::getMluTensorImpl(grad_out);
  auto grad_out_ptr = grad_out_impl->cnnlMalloc();
  auto grad_in_impl = torch_mlu::getMluTensorImpl(grad_in);
  auto grad_in_ptr = grad_in_impl->cnnlMalloc();
  // get compute dtype of input
  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(grad_out.dtype());
  // launch kernel RoiawarePool3dBackward
  CNLOG(INFO) << "Launch Kernel MLUKernel RoiawarePool3dBackward<<<" << k_dim.x
              << ", " << k_dim.y << ", " << k_dim.z << ">>>";
  KernelRoiawarePool3dBackward(k_dim, k_type, queue, data_type, pool_method,
                               boxes_num, out_x, out_y, out_z, channels,
                               max_pts_each_voxel, (int *)pts_idx_of_voxels_ptr,
                               (int *)argmax_ptr, grad_out_ptr, grad_in_ptr);
}

Added in this commit (delegation to mluOpRoiawarePool3dBackward):

  // get compute handle
  auto handle = mluOpGetCurrentHandle();
  auto pts_idx_of_voxels_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      pts_idx_of_voxels, pts_idx_of_voxels.suggest_memory_format());
  auto argmax_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      argmax, argmax.suggest_memory_format());
  auto grad_out_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      grad_out, grad_out.suggest_memory_format());
  auto grad_in_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
      grad_in, grad_in.suggest_memory_format());

  MluOpTensorDescriptor pts_idx_of_voxels_desc, argmax_desc, grad_out_desc,
      grad_in_desc;
  pts_idx_of_voxels_desc.set(pts_idx_of_voxels_contiguous);
  argmax_desc.set(argmax_contiguous);
  grad_out_desc.set(grad_out_contiguous);
  grad_in_desc.set(grad_in_contiguous);

  auto pts_idx_of_voxels_impl =
      torch_mlu::getMluTensorImpl(pts_idx_of_voxels_contiguous);
  auto pts_idx_of_voxels_ptr = pts_idx_of_voxels_impl->cnnlMalloc();
  auto argmax_impl = torch_mlu::getMluTensorImpl(argmax_contiguous);
  auto argmax_ptr = argmax_impl->cnnlMalloc();
  auto grad_out_impl = torch_mlu::getMluTensorImpl(grad_out_contiguous);
  auto grad_out_ptr = grad_out_impl->cnnlMalloc();
  auto grad_in_impl = torch_mlu::getMluTensorImpl(grad_in_contiguous);
  auto grad_in_ptr = grad_in_impl->cnnlMalloc();

  CNLOG(INFO) << "Call mluOpRoiawarePool3dBackward().";
  TORCH_MLUOP_CHECK(mluOpRoiawarePool3dBackward(
      handle, pool_method, boxes_num, out_x, out_y, out_z, channels,
      max_pts_each_voxel, pts_idx_of_voxels_desc.desc(), pts_idx_of_voxels_ptr,
      argmax_desc.desc(), argmax_ptr, grad_out_desc.desc(), grad_out_ptr,
      grad_in_desc.desc(), grad_in_ptr));
}

void roiaware_pool3d_backward_mlu(int boxes_num, int out_x, int out_y,
...
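For reference, the backward op routes grad_out back to per-point grad_in; in the max-pooling case the route is given by argmax, where -1 means "no contributing point" and therefore no gradient. A CPU-only C++ sketch of that scatter for one voxel, illustrating the general semantics rather than the mluOp kernel:

#include <cstdio>
#include <vector>

// Each output channel's gradient goes to the single point that won the
// forward max; voxels whose argmax stayed -1 contribute nothing.
static void max_pool_backward_voxel(const std::vector<float> &grad_out_voxel,
                                    const std::vector<int> &argmax_voxel,
                                    std::vector<std::vector<float>> &grad_in) {
  for (size_t c = 0; c < grad_out_voxel.size(); ++c) {
    int idx = argmax_voxel[c];
    if (idx >= 0) grad_in[idx][c] += grad_out_voxel[c];
  }
}

int main() {
  // Two points, two channels; voxel gradient (0.5, 1.0), argmax (1, 0).
  std::vector<std::vector<float>> grad_in(2, std::vector<float>(2, 0.f));
  max_pool_backward_voxel({0.5f, 1.0f}, {1, 0}, grad_in);
  std::printf("grad_in = [(%.1f, %.1f), (%.1f, %.1f)]\n", grad_in[0][0],
              grad_in[0][1], grad_in[1][0], grad_in[1][1]);
  // -> [(0.0, 1.0), (0.5, 0.0)]
  return 0;
}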