Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
a55f4b7f
You need to sign in or sign up before continuing.
Unverified
Commit
a55f4b7f
authored
Apr 03, 2023
by
liuduanhui
Committed by
GitHub
Apr 03, 2023
Browse files
[Enhancement] Replace the implementation of three_nn_forward with mlu-ops (#2719)
parent
f946a933
Changes
2
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
536 deletions
+33
-536
mmcv/ops/csrc/common/mlu/three_nn_mlu_kernel.mlu
mmcv/ops/csrc/common/mlu/three_nn_mlu_kernel.mlu
+0
-466
mmcv/ops/csrc/pytorch/mlu/three_nn_mlu.cpp
mmcv/ops/csrc/pytorch/mlu/three_nn_mlu.cpp
+33
-70
No files found.
mmcv/ops/csrc/common/mlu/three_nn_mlu_kernel.mlu
deleted
100644 → 0
View file @
f946a933
This diff is collapsed.
Click to expand it.
mmcv/ops/csrc/pytorch/mlu/three_nn_mlu.cpp
View file @
a55f4b7f
...
@@ -9,84 +9,47 @@
...
@@ -9,84 +9,47 @@
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
*************************************************************************/
#include "pytorch_device_registry.hpp"
#include "mlu_common_helper.h"
#include "pytorch_mlu_helper.hpp"
void
KernelThreeNNForward
(
cnrtDim3_t
k_dim
,
cnrtFunctionType_t
k_type
,
cnrtQueue_t
queue
,
cnrtDataType_t
data_type
,
const
void
*
unknown
,
const
void
*
known
,
void
*
dist2
,
int
*
idx
,
const
int
b
,
const
int
n
,
const
int
m
);
void
ThreeNNMLUKernelLauncher
(
int
b
,
int
n
,
int
m
,
const
Tensor
unknown
,
void
ThreeNNMLUKernelLauncher
(
int
b
,
int
n
,
int
m
,
const
Tensor
unknown
,
const
Tensor
known
,
Tensor
dist2
,
Tensor
idx
)
{
const
Tensor
known
,
Tensor
dist2
,
Tensor
idx
)
{
// Check dtype.
auto
unknown_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
TORCH_CHECK
(
unknown
,
unknown
.
suggest_memory_format
());
unknown
.
scalar_type
()
==
at
::
kFloat
||
unknown
.
scalar_type
()
==
at
::
kHalf
,
auto
known_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
"unknown type should be Float or Half, got "
,
unknown
.
scalar_type
(),
"."
);
known
,
known
.
suggest_memory_format
());
TORCH_CHECK
(
unknown
.
scalar_type
()
==
known
.
scalar_type
(),
auto
dist2_contiguous
=
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
"known should have the same type as unknown."
);
dist2
,
dist2
.
suggest_memory_format
());
TORCH_CHECK
(
unknown
.
scalar_type
()
==
dist2
.
scalar_type
(),
auto
idx_contiguous
=
"dist2 should have the same type as unknown."
);
torch_mlu
::
cnnl
::
ops
::
cnnl_contiguous
(
idx
,
idx
.
suggest_memory_format
());
TORCH_CHECK
(
idx
.
scalar_type
()
==
at
::
kInt
,
"idx type should be Int."
);
MluOpTensorDescriptor
unknown_desc
,
known_desc
,
dist2_desc
,
idx_desc
;
// Check shape.
unknown_desc
.
set
(
unknown_contiguous
);
TORCH_CHECK
(
unknown
.
dim
()
==
3
,
"unknown should be 3d tensor, got "
,
known_desc
.
set
(
known_contiguous
);
unknown
.
dim
(),
"D."
);
dist2_desc
.
set
(
dist2_contiguous
);
TORCH_CHECK
(
known
.
dim
()
==
3
,
"known should be 3d tensor, got "
,
known
.
dim
(),
idx_desc
.
set
(
idx_contiguous
);
"D."
);
TORCH_CHECK
(
unknown
.
size
(
0
)
==
known
.
size
(
0
),
auto
handle
=
mluOpGetCurrentHandle
();
"known.dim0 should be equal to unknown.dim0, got "
,
known
.
size
(
0
),
size_t
workspace_size
=
0
;
"."
);
mluOpGetThreeNNForwardWorkspaceSize
(
handle
,
known_desc
.
desc
(),
TORCH_CHECK
(
unknown
.
size
(
2
)
==
3
,
"unknown dim2 should be 3, got "
,
&
workspace_size
);
unknown
.
size
(
2
),
"."
);
auto
known_workspace
=
TORCH_CHECK
(
known
.
size
(
2
)
==
3
,
"known dim2 should be 3, got "
,
known
.
size
(
2
),
at
::
empty
(
workspace_size
,
known
.
options
().
dtype
(
at
::
kByte
));
"."
);
auto
unknown_impl
=
torch_mlu
::
getMluTensorImpl
(
unknown_contiguous
);
// zero element check
auto
known_impl
=
torch_mlu
::
getMluTensorImpl
(
known_contiguous
);
TORCH_CHECK
(
unknown
.
numel
()
>
0
,
auto
dist2_impl
=
torch_mlu
::
getMluTensorImpl
(
dist2_contiguous
);
"unknown.numel should greater than zero, got "
,
unknown
.
numel
(),
auto
idx_impl
=
torch_mlu
::
getMluTensorImpl
(
idx_contiguous
);
"."
);
auto
workspace_impl
=
torch_mlu
::
getMluTensorImpl
(
known_workspace
);
if
(
known
.
numel
()
==
0
)
{
// return if known zero element
return
;
}
// large tensor check
const
size_t
max_input_num
=
2147483648
;
// 2^31, 2G num
TORCH_CHECK
(
unknown
.
numel
()
<
max_input_num
,
"unknown.numel() should be less than 2147483648, got "
,
unknown
.
numel
(),
"."
);
TORCH_CHECK
(
known
.
numel
()
<
max_input_num
,
"known.numel() should be less than 2147483648, got "
,
known
.
numel
(),
"."
);
// get compute queue
auto
queue
=
torch_mlu
::
getCurQueue
();
// get ptr of tensors
auto
unknown_impl
=
torch_mlu
::
getMluTensorImpl
(
unknown
);
auto
unknown_ptr
=
unknown_impl
->
cnnlMalloc
();
auto
unknown_ptr
=
unknown_impl
->
cnnlMalloc
();
auto
known_t
=
known
.
permute
({
0
,
2
,
1
}).
contiguous
();
auto
known_impl
=
torch_mlu
::
getMluTensorImpl
(
known_t
);
auto
known_ptr
=
known_impl
->
cnnlMalloc
();
auto
known_ptr
=
known_impl
->
cnnlMalloc
();
auto
dist2_impl
=
torch_mlu
::
getMluTensorImpl
(
dist2
);
auto
dist2_ptr
=
dist2_impl
->
cnnlMalloc
();
auto
dist2_ptr
=
dist2_impl
->
cnnlMalloc
();
auto
idx_impl
=
torch_mlu
::
getMluTensorImpl
(
idx
);
auto
idx_ptr
=
idx_impl
->
cnnlMalloc
();
auto
idx_ptr
=
idx_impl
->
cnnlMalloc
();
auto
workspace_ptr
=
workspace_impl
->
cnnlMalloc
();
cnrtJobType_t
k_type
=
CNRT_FUNC_TYPE_UNION1
;
mluOpThreeNNForward
(
handle
,
unknown_desc
.
desc
(),
unknown_ptr
,
cnrtDim3_t
k_dim
;
known_desc
.
desc
(),
known_ptr
,
workspace_ptr
,
k_dim
.
x
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrMcorePerCluster
);
workspace_size
,
dist2_desc
.
desc
(),
dist2_ptr
,
k_dim
.
y
=
torch_mlu
::
getDeviceAttr
(
cnrtAttrClusterCount
);
idx_desc
.
desc
(),
idx_ptr
);
k_dim
.
z
=
1
;
cnrtDataType_t
data_type
=
torch_mlu
::
toCnrtDtype
(
unknown
.
dtype
());
// launch kernel
CNLOG
(
INFO
)
<<
"Launch Kernel MLUKernelThreeNNForward<<<"
<<
k_dim
.
x
<<
", "
<<
k_dim
.
y
<<
", "
<<
k_dim
.
z
<<
">>>."
;
KernelThreeNNForward
(
k_dim
,
k_type
,
queue
,
data_type
,
unknown_ptr
,
known_ptr
,
dist2_ptr
,
(
int
*
)
idx_ptr
,
b
,
n
,
m
);
}
}
void
three_nn_forward_mlu
(
int
b
,
int
n
,
int
m
,
const
Tensor
unknown
,
void
three_nn_forward_mlu
(
int
b
,
int
n
,
int
m
,
const
Tensor
unknown
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment