OpenDAS / MMCV, commit 92b3e861 (unverified)
Authored Jun 01, 2023 by liuduanhui; committed via GitHub, Jun 01, 2023

[Refactor] Replace the implementation of psa_mask with mlu-ops. (#2810)
Parent: 2611b990

Showing 3 changed files with 14 additions and 883 deletions:

  mmcv/ops/csrc/common/mlu/psamask_mlu_kernel.mlu   +0   -615
  mmcv/ops/csrc/common/mlu/psamask_utils.hpp        +0   -55
  mmcv/ops/csrc/pytorch/mlu/psamask_mlu.cpp         +14  -213
mmcv/ops/csrc/common/mlu/psamask_mlu_kernel.mlu  deleted (100644 → 0)
(615-line diff collapsed, not shown)
mmcv/ops/csrc/common/mlu/psamask_utils.hpp  deleted (100644 → 0)
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef PSAMASK_UTILS_HPP_
#define PSAMASK_UTILS_HPP_

typedef enum {
  COLLECT = 0,
  DISTRIBUTE = 1,
} PsamaskType;

typedef enum {
  PARTITION_N = 0,
  PARTITION_H = 1,
} DimPartitionType;

struct PartitionSeg {
  int h_per_cluster;
  int n_per_cluster;
  int h_per_core;
  int n_per_core;
  DimPartitionType cluster_partition;
  DimPartitionType core_partition;
};

struct Shape {
  int n;
  int h;
  int w;
  int c;
};

struct LimitParam {
  int n;
  int h;
  int w;
};

struct PositionInCore {
  int n_start;
  int n_end;
  int h_start;
  int h_end;
  int w_start;
  int w_end;
};

#endif  // PSAMASK_UTILS_HPP_
mmcv/ops/csrc/pytorch/mlu/psamask_mlu.cpp  (+14, -213)
@@ -9,136 +9,7 @@
  * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
  * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  *************************************************************************/
- #include <algorithm>
- #include "psamask_utils.hpp"
- #include "pytorch_device_registry.hpp"
- #include "pytorch_mlu_helper.hpp"
-
- #define COMPUTE_COUNT_ALIGN 64
-
- void KernelPsamaskForward(
-     cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
-     const void *x, void *y, const PsamaskType psa_type,
-     const DimPartitionType core_partition,
-     const DimPartitionType cluster_partition, const int batch,
-     const int h_feature, const int w_feature, const int h_mask,
-     const int w_mask, const int x_c, const int y_c, const int half_h_mask,
-     const int half_w_mask, const int n_per_core, const int h_per_core,
-     const int n_per_cluster, const int h_per_cluster, const int limit_n_seg,
-     const int limit_h_seg, const int limit_w_seg);
-
- void KernelPsamaskBackward(
-     cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
-     const void *dy, void *dx, const PsamaskType psa_type,
-     const DimPartitionType core_partition,
-     const DimPartitionType cluster_partition, const int batch,
-     const int h_feature, const int w_feature, const int h_mask,
-     const int w_mask, const int dx_c, const int dy_c, const int half_h_mask,
-     const int half_w_mask, const int n_per_core, const int h_per_core,
-     const int n_per_cluster, const int h_per_cluster, const int limit_n_seg,
-     const int limit_h_seg, const int limit_w_seg);
-
- namespace {
- void policyFunc(cnrtDim3_t *k_dim_ptr, cnrtFunctionType_t *f_type_ptr,
-                 PartitionSeg *partition_ptr, const int n,
-                 const int h_feature) {
-   unsigned int core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
-   unsigned int cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
-   unsigned int use_cluster_num = cluster_num;
-   unsigned int use_core_num = core_dim;
-   if (n >= cluster_num || n >= h_feature) {
-     partition_ptr->cluster_partition = PARTITION_N;
-     partition_ptr->n_per_cluster = (n + cluster_num - 1) / cluster_num;
-     partition_ptr->h_per_cluster = h_feature;
-     use_cluster_num =
-         (n + partition_ptr->n_per_cluster - 1) / partition_ptr->n_per_cluster;
-   } else {
-     partition_ptr->cluster_partition = PARTITION_H;
-     partition_ptr->h_per_cluster = (h_feature + cluster_num - 1) / cluster_num;
-     partition_ptr->n_per_cluster = n;
-     use_cluster_num = (h_feature + partition_ptr->h_per_cluster - 1) /
-                       partition_ptr->h_per_cluster;
-   }
-   if (partition_ptr->n_per_cluster >= core_dim ||
-       partition_ptr->n_per_cluster >= partition_ptr->h_per_cluster) {
-     partition_ptr->core_partition = PARTITION_N;
-     partition_ptr->n_per_core =
-         (partition_ptr->n_per_cluster + core_dim - 1) / core_dim;
-     partition_ptr->h_per_core = partition_ptr->h_per_cluster;
-     use_core_num =
-         (partition_ptr->n_per_cluster + partition_ptr->n_per_core - 1) /
-         partition_ptr->n_per_core;
-   } else {
-     partition_ptr->core_partition = PARTITION_H;
-     partition_ptr->h_per_core =
-         (partition_ptr->h_per_cluster + core_dim - 1) / core_dim;
-     partition_ptr->n_per_core = partition_ptr->n_per_cluster;
-     use_core_num =
-         (partition_ptr->h_per_cluster + partition_ptr->h_per_core - 1) /
-         partition_ptr->h_per_core;
-   }
-   *k_dim_ptr = {core_dim, use_cluster_num, 1};
- }
- }  // namespace
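Note on the deleted policy: policyFunc splits the workload across clusters first and then across each cluster's cores, switching between a batch split (PARTITION_N) and a feature-height split (PARTITION_H) by plain ceil division. A standalone sketch of that arithmetic, with an assumed device shape of 4 clusters x 4 cores standing in for the torch_mlu::getDeviceAttr queries:

// Standalone sketch of the partition arithmetic above; the device shape and
// the workload (n = 6, h_feature = 32) are assumed sample values, not
// anything queried from a real MLU.
#include <cstdio>

int main() {
  const int cluster_num = 4, core_dim = 4;  // assumed device attributes
  const int n = 6, h_feature = 32;          // sample workload

  // n >= cluster_num, so clusters split the batch (PARTITION_N):
  int n_per_cluster = (n + cluster_num - 1) / cluster_num;        // ceil(6/4) = 2
  int use_cluster_num = (n + n_per_cluster - 1) / n_per_cluster;  // ceil(6/2) = 3

  // n_per_cluster (2) is below both core_dim (4) and h_per_cluster (32),
  // so each cluster's cores split the feature height (PARTITION_H):
  int h_per_core = (h_feature + core_dim - 1) / core_dim;         // ceil(32/4) = 8

  printf("%d clusters used, %d rows per core\n", use_cluster_num, h_per_core);
  return 0;
}

The *k_dim_ptr = {core_dim, use_cluster_num, 1} assignment at the end of policyFunc is what launched that split.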
- bool findLimit(const int shape_core_n, const int shape_core_h,
-                const int shape_core_w, const int shape_core_ci,
-                const int shape_core_co, int *limit_n_seg_ptr,
-                int *limit_h_seg_ptr, int *limit_w_seg_ptr,
-                const int psa_type) {
-   const bool need_temp = psa_type == 1;
-   const int input_bytes = sizeof(float);
-   int limit_n_seg = shape_core_n;
-   int limit_h_seg = shape_core_h;
-   int limit_w_seg = shape_core_w;
-   const int max_nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
-   const int align_base_128 = NFU_ALIGN_SIZE / input_bytes;
-   const int align_base_64 = COMPUTE_COUNT_ALIGN / input_bytes;
-   const int align_co = CEIL_ALIGN(shape_core_co, align_base_64);
-   const int align_w = CEIL_ALIGN(shape_core_w, align_base_64);
-   const int align_hw = CEIL_ALIGN(shape_core_h * shape_core_w, align_base_64);
-   const int max_num = max_nram_size / input_bytes;
-   int n_limit =
-       max_num /
-       (CEIL_ALIGN(shape_core_h * shape_core_w * shape_core_ci,
-                   align_base_128) +
-        align_hw * align_co * (1 + need_temp));
-   if (n_limit > 0) {
-     n_limit = std::min(n_limit, shape_core_n);
-     limit_n_seg = n_limit;
-   } else {
-     int h_limit =
-         max_num / (CEIL_ALIGN(shape_core_w * shape_core_ci, align_base_128) +
-                    align_w * align_co * (1 + need_temp));
-     if (h_limit > 0) {
-       h_limit = std::min(h_limit, shape_core_h);
-       limit_h_seg = h_limit;
-       limit_n_seg = 1;
-     } else {
-       int w_limit =
-           max_num / (CEIL_ALIGN(shape_core_ci, align_base_128) +
-                      CEIL_ALIGN(align_co, align_base_128) * (1 + need_temp));
-       if (w_limit > 0 && w_limit >= (COMPUTE_COUNT_ALIGN / input_bytes)) {
-         w_limit = std::min(w_limit, shape_core_w);
-         w_limit = w_limit / (COMPUTE_COUNT_ALIGN / input_bytes) *
-                   (COMPUTE_COUNT_ALIGN / input_bytes);
-         limit_w_seg = w_limit;
-         limit_h_seg = 1;
-         limit_n_seg = 1;
-       } else {
-         CNLOG(INFO) << "The size of input channel is too large.";
-         return false;
-       }
-     }
-   }
-   *limit_n_seg_ptr = limit_n_seg;
-   *limit_h_seg_ptr = limit_h_seg;
-   *limit_w_seg_ptr = limit_w_seg;
-   return true;
- }
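findLimit, deleted with it, chose the largest per-iteration segment whose aligned input tile plus aligned output tile (doubled when DISTRIBUTE needs a temporary) still fits in per-core NRAM, trying whole batch elements first, then rows, then aligned column chunks. A rough rendition of the first of those three steps, under assumed sizes (512 KiB of NRAM per core, float32 data, and NFU_ALIGN_SIZE presumed to be 128 bytes in pytorch_mlu_helper.hpp):

// Rough arithmetic behind the n_limit branch above. All sizes are assumed
// sample values; ceil_align is re-derived here rather than taken from the
// CEIL_ALIGN macro in pytorch_mlu_helper.hpp.
#include <cstdio>

int ceil_align(int x, int align) { return (x + align - 1) / align * align; }

int main() {
  const int max_num = 512 * 1024 / 4;  // assumed NRAM bytes per core, in floats
  const int h = 32, w = 32;            // per-core feature tile
  const int ci = 9 * 9;                // input channels  = h_mask * w_mask
  const int co = 32 * 32;              // output channels = h_feature * w_feature
  const bool need_temp = false;        // COLLECT needs no temporary buffer

  // Footprint of one batch element: aligned input plus aligned output tile(s).
  int per_n = ceil_align(h * w * ci, 128 / 4) +
              ceil_align(h * w, 64 / 4) * ceil_align(co, 64 / 4) * (1 + need_temp);
  int n_limit = max_num / per_n;
  printf("n_limit = %d\n", n_limit);  // 0 here: fall back to splitting rows (h)
  return 0;
}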
#include "mlu_common_helper.h"
void
PSAMaskForwardMLUKernelLauncher
(
const
int
psa_type
,
const
Tensor
x
,
Tensor
y
,
const
int
num_
,
...
...
@@ -146,39 +17,7 @@ void PSAMaskForwardMLUKernelLauncher(const int psa_type, const Tensor x,
                                     const int h_mask, const int w_mask,
                                     const int half_h_mask,
                                     const int half_w_mask) {
-   // params check
-   TORCH_CHECK(x.scalar_type() == at::kFloat, "x type should be Float, got ",
-               x.scalar_type());
-   TORCH_CHECK(y.scalar_type() == x.scalar_type(),
-               "y should have the same type as x");
-   TORCH_CHECK(x.dim() == 4, "x should be a 4d tensor, got ", x.dim(), "D");
-   TORCH_CHECK(y.dim() == 4, "y should be a 4d tensor, got ", y.dim(), "D");
-   int x_c = x.size(1);
-   int y_c = y.size(1);
-   TORCH_CHECK(h_mask * w_mask == x_c,
-               "channel of x should be the same as h_mask * w_mask");
-   TORCH_CHECK(h_feature * w_feature == y_c,
-               "channel of y should be the same as h_feature * w_feature");
-   TORCH_CHECK(psa_type == 0 || psa_type == 1,
-               "psa_type only supports 'COLLECT' and 'DISTRIBUTE' currently");
-   if (x.numel() == 0) {
-     CNLOG(INFO) << "skip zero-element tensor";
-     return;
-   }
-   cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_UNION1;
-   cnrtDim3_t k_dim;
-   PartitionSeg partition_info;
-   policyFunc(&k_dim, &k_type, &partition_info, num_, h_feature);
-   int n_limit_seg, h_limit_seg, w_limit_seg;
-   bool ret = findLimit(partition_info.n_per_core, partition_info.h_per_core,
-                        w_feature, x_c, y_c, &n_limit_seg, &h_limit_seg,
-                        &w_limit_seg, psa_type);
-   if (ret != true) {
-     return;
-   }
  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(x.dim());

@@ -186,22 +25,18 @@ void PSAMaskForwardMLUKernelLauncher(const int psa_type, const Tensor x,
  at::Tensor y_tmp = at::empty({num_, y_c, h_feature, w_feature}, x.options(),
                               memory_format);
-   // get compute queue
-   auto queue = torch_mlu::getCurQueue();
+   MluOpTensorDescriptor x_desc, y_desc;
+   x_desc.set_with_layout(x_tensor, MLUOP_LAYOUT_NHWC);
+   y_desc.set_with_layout(y_tmp, MLUOP_LAYOUT_NHWC);
  // get ptr of tensors
+   auto handle = mluOpGetCurrentHandle();
  auto x_impl = torch_mlu::getMluTensorImpl(x_tensor);
  auto x_ptr = x_impl->cnnlMalloc();
  auto y_impl = torch_mlu::getMluTensorImpl(y_tmp);
  auto y_ptr = y_impl->cnnlMalloc();
-   KernelPsamaskForward(
-       k_dim, k_type, queue, x_ptr, y_ptr, (PsamaskType)psa_type,
-       partition_info.core_partition, partition_info.cluster_partition, num_,
-       h_feature, w_feature, h_mask, w_mask, x_c, y_c, half_h_mask,
-       half_w_mask, partition_info.n_per_core, partition_info.h_per_core,
-       partition_info.n_per_cluster, partition_info.h_per_cluster,
-       n_limit_seg, h_limit_seg, w_limit_seg);
+   mluOpPsamaskForward(handle, psa_type, x_desc.desc(), x_ptr, h_mask, w_mask,
+                       y_desc.desc(), y_ptr);
  y.copy_(y_tmp);
}
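A note on the layout round-trip in the new path: the descriptors above declare NHWC (MLUOP_LAYOUT_NHWC), while mmcv hands the launcher NCHW tensors. That is why the input goes through a channels-last cnnl_contiguous, the output is staged in the channels-last y_tmp, and the result is copied back with y.copy_(y_tmp). A minimal self-contained illustration of what that reordering means for a dense float buffer (not mmcv code; the real path does this through tensor strides rather than an explicit copy loop):

// NCHW -> NHWC on a dense buffer: same values, with the channel index made
// the fastest-varying one. Purely illustrative.
#include <vector>

std::vector<float> nchw_to_nhwc(const std::vector<float>& src,
                                int n, int c, int h, int w) {
  std::vector<float> dst(src.size());
  for (int in = 0; in < n; ++in)
    for (int ic = 0; ic < c; ++ic)
      for (int ih = 0; ih < h; ++ih)
        for (int iw = 0; iw < w; ++iw)
          dst[((in * h + ih) * w + iw) * c + ic] =     // NHWC offset
              src[((in * c + ic) * h + ih) * w + iw];  // NCHW offset
  return dst;
}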
@@ -212,39 +47,7 @@ void PSAMaskBackwardMLUKernelLauncher(const int psa_type, const Tensor dy,
                                      const int h_mask, const int w_mask,
                                      const int half_h_mask,
                                      const int half_w_mask) {
-   // params check
-   TORCH_CHECK(dy.scalar_type() == at::kFloat, "dy type should be Float, got ",
-               dy.scalar_type());
-   TORCH_CHECK(dx.scalar_type() == dy.scalar_type(),
-               "dx should have the same type as dy");
-   TORCH_CHECK(dy.dim() == 4, "dy should be a 4d tensor, got ", dy.dim(), "D");
-   TORCH_CHECK(dx.dim() == 4, "dx should be a 4d tensor, got ", dx.dim(), "D");
-   int dy_c = dy.size(1);
-   int dx_c = dx.size(1);
-   TORCH_CHECK(h_feature * w_feature == dy_c,
-               "channel of dy should be the same as h_feature * w_feature");
-   TORCH_CHECK(h_mask * w_mask == dx_c,
-               "channel of dx should be the same as h_mask * w_mask");
-   TORCH_CHECK(psa_type == 0 || psa_type == 1,
-               "psa_type only supports 'COLLECT' and 'DISTRIBUTE' currently");
-   if (dx.numel() == 0) {
-     CNLOG(INFO) << "skip zero-element tensor";
-     return;
-   }
-   cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_UNION1;
-   cnrtDim3_t k_dim;
-   PartitionSeg partition_info;
-   policyFunc(&k_dim, &k_type, &partition_info, num_, h_feature);
-   int n_limit_seg, h_limit_seg, w_limit_seg;
-   bool ret = findLimit(partition_info.n_per_core, partition_info.h_per_core,
-                        w_feature, dx_c, dy_c, &n_limit_seg, &h_limit_seg,
-                        &w_limit_seg, psa_type);
-   if (ret != true) {
-     return;
-   }
  auto memory_format =
      torch_mlu::cnnl::ops::get_channels_last_memory_format(dy.dim());

@@ -252,8 +55,11 @@ void PSAMaskBackwardMLUKernelLauncher(const int psa_type, const Tensor dy,
  at::Tensor dx_tmp = at::empty({num_, dx_c, h_feature, w_feature},
                                dy.options(), memory_format);
-   // get compute queue
-   auto queue = torch_mlu::getCurQueue();
+   MluOpTensorDescriptor dy_desc, dx_tmp_desc;
+   dy_desc.set_with_layout(dy_tensor, MLUOP_LAYOUT_NHWC);
+   dx_tmp_desc.set_with_layout(dx_tmp, MLUOP_LAYOUT_NHWC);
+   auto handle = mluOpGetCurrentHandle();
  // get ptr of tensors
  auto dx_impl = torch_mlu::getMluTensorImpl(dx_tmp);

@@ -261,13 +67,8 @@ void PSAMaskBackwardMLUKernelLauncher(const int psa_type, const Tensor dy,
  auto dy_impl = torch_mlu::getMluTensorImpl(dy_tensor);
  auto dy_ptr = dy_impl->cnnlMalloc();
-   KernelPsamaskBackward(
-       k_dim, k_type, queue, dy_ptr, dx_ptr, (PsamaskType)psa_type,
-       partition_info.core_partition, partition_info.cluster_partition, num_,
-       h_feature, w_feature, h_mask, w_mask, dx_c, dy_c, half_h_mask,
-       half_w_mask, partition_info.n_per_core, partition_info.h_per_core,
-       partition_info.n_per_cluster, partition_info.h_per_cluster,
-       n_limit_seg, h_limit_seg, w_limit_seg);
+   mluOpPsamaskBackward(handle, psa_type, dy_desc.desc(), dy_ptr, h_mask,
+                        w_mask, dx_tmp_desc.desc(), dx_ptr);
  dx.copy_(dx_tmp);
}
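The checks deleted above encoded psa_mask's shape contract, presumably validated inside mlu-ops from now on: the forward input x carries h_mask * w_mask channels, the output y carries h_feature * w_feature channels, and dy/dx mirror those roles in backward, for COLLECT and DISTRIBUTE alike. A tiny self-check under assumed sizes, a 9x9 mask over a 32x32 feature map:

// Shape contract from the deleted TORCH_CHECKs, with assumed sample sizes:
// x: (N, 81, 32, 32) -> y: (N, 1024, 32, 32).
#include <cassert>

int main() {
  const int h_feature = 32, w_feature = 32;  // feature map size (assumed)
  const int h_mask = 9, w_mask = 9;          // attention window (assumed)
  const int x_c = h_mask * w_mask;           // forward input channels
  const int y_c = h_feature * w_feature;     // forward output channels
  assert(x_c == 81 && y_c == 1024);          // dy/dx swap these roles in backward
  return 0;
}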