OpenDAS / MMCV / Commits

Unverified commit 6e9ee267, authored Aug 28, 2023 by Chris Jiang, committed via GitHub on Aug 28, 2023.
[Refactor] Replace tin_shift op of MLU backend with mlu-ops (#2910)
parent 8b8bf5e1

Showing 5 changed files with 76 additions and 444 deletions (+76 −444):
mmcv/ops/csrc/common/mlu/tin_shift_mlu_kernel.mlu   +0 −307
mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp   +14 −14
mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h   +2 −2
mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp   +15 −11
mmcv/ops/csrc/pytorch/mlu/tin_shift_mlu.cpp   +45 −110
mmcv/ops/csrc/common/mlu/tin_shift_mlu_kernel.mlu
deleted 100644 → 0
/*************************************************************************
* Copyright (C) 2022 Cambricon.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "common_mlu_helper.hpp"
__nram__ char data_nram[MAX_NRAM_SIZE];
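// Direct path: one task handles one (batch, channel) pair. The host-side policy
// picks this kernel when a channel's whole (time, hw) slice fits in NRAM
// (channel_per_core >= 1); each time step's hw_size block is copied to its
// shifted position in the output, and steps shifted out of range stay zero-filled.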
template <typename T>
__mlu_func__ void mluMultiKernelTinShift(
const T *input, const int *shifts, T *output, const int batch_size,
const int time_size, const int channel_size, const int hw_size,
const int group_size, const int group_channel) {
for (int cur_channel_index = taskId;
cur_channel_index < batch_size * channel_size;
cur_channel_index += taskDim) {
int n_index = cur_channel_index / channel_size;
int group_id = cur_channel_index % channel_size / group_channel;
int t_shift = shifts[n_index * group_size + group_id];
int index = cur_channel_index % channel_size * hw_size +
n_index * time_size * channel_size * hw_size;
__bang_write_value(data_nram, MAX_NRAM_SIZE, (char)0);
__asm__ volatile("sync;");
if (abs(t_shift) >= time_size) {
__memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
time_size - 1);
} else {
if (t_shift > 0) {
__memcpy(data_nram + t_shift * hw_size * sizeof(T), input + index,
hw_size * sizeof(T), GDRAM2NRAM, hw_size * sizeof(T),
channel_size * hw_size * sizeof(T), time_size - 1 - t_shift);
__memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
time_size - 1);
} else {
__memcpy(data_nram, input + (index - t_shift * channel_size * hw_size),
hw_size * sizeof(T), GDRAM2NRAM, hw_size * sizeof(T),
channel_size * hw_size * sizeof(T), time_size - 1 + t_shift);
__memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
time_size - 1);
}
}
__asm__ volatile("sync;");
}
}
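// Fallback used when even a single time step's hw_size block does not fit in
// NRAM (max_number_hw_per_core == 0): the hw elements of one time step are
// staged through NRAM in chunks of at most max_length_per_core elements.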
template <typename T>
__mlu_func__ void mluHwSplit(const T *input, const int t_shift,
const int time_size, const int hw_size,
const int channel_size, const int index,
const int cur_sequence_index,
const int max_length_per_core, T *output) {
for (int cur_index = index; cur_index < index + hw_size;
cur_index += max_length_per_core) {
int memcpy_size = max_length_per_core;
if (cur_index + max_length_per_core > index + hw_size) {
memcpy_size = index + hw_size - cur_index;
}
if (cur_sequence_index - t_shift < 0 ||
cur_sequence_index - t_shift >= time_size) {
__memcpy(output + cur_index, data_nram, memcpy_size * sizeof(T),
NRAM2GDRAM);
} else {
__memcpy(data_nram, input + cur_index - t_shift * channel_size * hw_size,
memcpy_size * sizeof(T), GDRAM2NRAM);
__memcpy(output + cur_index, data_nram, memcpy_size * sizeof(T),
NRAM2GDRAM);
}
__asm__ volatile("sync;");
}
}
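// Split-sequence path: chosen by the host when a whole (time, hw) slice per
// channel is too large for NRAM (channel_per_core == 0). Each task processes one
// segment of at most max_number_hw_per_core consecutive time steps of a
// (batch, channel) pair.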
template <typename T>
__mlu_func__ void mluMultiKernelTinShiftSplitSequence(
const T *input, const int *shifts, T *output, const int batch_size,
const int time_size, const int channel_size, const int hw_size,
const int group_size, const int group_channel,
const int max_number_hw_per_core, const int max_length_per_core) {
const int tmp_max_number_hw_per_core =
max_number_hw_per_core > 0 ? max_number_hw_per_core : 1;
const int loop_time = time_size / tmp_max_number_hw_per_core +
((time_size % tmp_max_number_hw_per_core) > 0 ? 1 : 0);
int segmentime_size = tmp_max_number_hw_per_core;
int res_segment = time_size % tmp_max_number_hw_per_core;
for (int cur_segment_index = taskId;
cur_segment_index < loop_time * batch_size * channel_size;
cur_segment_index += taskDim) {
int n_index = cur_segment_index / loop_time / channel_size;
int group_id = cur_segment_index / loop_time % channel_size / group_channel;
int t_shift = shifts[n_index * group_size + group_id];
int index = n_index * time_size * channel_size * hw_size +
(cur_segment_index / loop_time % channel_size) * hw_size +
cur_segment_index % loop_time * segmentime_size * hw_size *
channel_size;
char *dst_gdram2nram = data_nram;
const T *src_gdram2nram = input + index;
int count_gdram2nram = -1;
int count_nram2gdram = -1;
int next_sequence_index =
index / hw_size / channel_size % time_size + segmentime_size;
int cur_sequence_index = index / hw_size / channel_size % time_size;
__bang_write_value(data_nram, MAX_NRAM_SIZE, (char)0);
__asm__ volatile("sync;");
if (max_number_hw_per_core == 0) {
mluHwSplit(input, t_shift, time_size, hw_size, channel_size, index,
cur_sequence_index, max_length_per_core, output);
continue;
}
if (abs(t_shift) >= time_size) {
if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) {
__memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
res_segment - 1);
} else {
__memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
segmentime_size - 1);
}
continue;
}
if (t_shift == 0) {
if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) {
dst_gdram2nram = data_nram;
src_gdram2nram = input + index;
count_gdram2nram = res_segment - 1;
count_nram2gdram = res_segment - 1;
} else {
dst_gdram2nram = data_nram;
src_gdram2nram = input + index;
count_gdram2nram = segmentime_size - 1;
count_nram2gdram = segmentime_size - 1;
}
} else if (t_shift > 0) {
int first_index_cur_channel =
n_index * time_size * channel_size * hw_size +
(cur_segment_index / loop_time % channel_size) * hw_size;
if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) {
dst_gdram2nram = data_nram;
src_gdram2nram =
input +
(index - t_shift * channel_size * hw_size < first_index_cur_channel
? first_index_cur_channel
: index - t_shift * channel_size * hw_size);
count_gdram2nram = res_segment - 1;
count_nram2gdram = res_segment - 1;
if (cur_sequence_index < t_shift && t_shift < next_sequence_index) {
dst_gdram2nram =
data_nram + t_shift % segmentime_size * hw_size * sizeof(T);
count_gdram2nram = res_segment - (t_shift - cur_sequence_index) - 1;
}
} else {
if (t_shift >= next_sequence_index) {
__memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
segmentime_size - 1);
continue;
} else if (cur_sequence_index < t_shift &&
t_shift < next_sequence_index) {
dst_gdram2nram =
data_nram + t_shift % segmentime_size * hw_size * sizeof(T);
src_gdram2nram = input + first_index_cur_channel;
count_gdram2nram = segmentime_size - (t_shift % segmentime_size) - 1;
count_nram2gdram = segmentime_size - 1;
} else {
dst_gdram2nram = data_nram;
src_gdram2nram = input + index - t_shift * channel_size * hw_size;
count_gdram2nram = segmentime_size - 1;
count_nram2gdram = segmentime_size - 1;
}
}
} else {
int offset_index = time_size + t_shift;
if (cur_sequence_index >= offset_index) {
if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) {
__memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
res_segment - 1);
continue;
} else {
__memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
segmentime_size - 1);
continue;
}
} else {
dst_gdram2nram = data_nram;
src_gdram2nram = input + index - t_shift * channel_size * hw_size;
if (cur_sequence_index - t_shift + segmentime_size < time_size) {
count_gdram2nram = segmentime_size - 1;
count_nram2gdram = segmentime_size - 1;
} else {
count_gdram2nram = time_size - (cur_sequence_index - t_shift) - 1;
count_nram2gdram =
(segmentime_size - 1) < (time_size - cur_sequence_index - 1)
? (segmentime_size - 1)
: (time_size - cur_sequence_index - 1);
}
}
}
__memcpy(dst_gdram2nram, src_gdram2nram, hw_size * sizeof(T), GDRAM2NRAM,
hw_size * sizeof(T), channel_size * hw_size * sizeof(T),
count_gdram2nram);
__memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM,
channel_size * hw_size * sizeof(T), hw_size * sizeof(T),
count_nram2gdram);
__asm__ volatile("sync;");
}
}
__mlu_entry__ void MLUUnion1KernelTinShift(
const void *input, const void *shifts, void *output, const int batch_size,
const int time_size, const int channel_size, const int hw_size,
const int group_size, const int group_channel,
const cnrtDataType_t data_dtype) {
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
switch (data_dtype) {
case CNRT_FLOAT16: {
mluMultiKernelTinShift((half *)input, (const int *)shifts, (half *)output,
batch_size, time_size, channel_size, hw_size,
group_size, group_channel);
}; break;
case CNRT_FLOAT32: {
mluMultiKernelTinShift((float *)input, (const int *)shifts,
(float *)output, batch_size, time_size,
channel_size, hw_size, group_size, group_channel);
}; break;
default: { return; }
}
}
__mlu_entry__ void MLUUnion1KernelTinShiftSplitSequence(
const void *input, const void *shifts, void *output, const int batch_size,
const int time_size, const int channel_size, const int hw_size,
const int group_size, const int group_channel,
const int max_number_hw_per_core, const int max_length_per_core,
const cnrtDataType_t data_dtype) {
// make sure that memcore is not used
if (coreId == 0x80) {
return;
}
switch (data_dtype) {
case CNRT_FLOAT16: {
mluMultiKernelTinShiftSplitSequence(
(half *)input, (const int *)shifts, (half *)output, batch_size,
time_size, channel_size, hw_size, group_size, group_channel,
max_number_hw_per_core, max_length_per_core);
}; break;
case CNRT_FLOAT32: {
mluMultiKernelTinShiftSplitSequence(
(float *)input, (const int *)shifts, (float *)output, batch_size,
time_size, channel_size, hw_size, group_size, group_channel,
max_number_hw_per_core, max_length_per_core);
}; break;
default: { return; }
}
}
void KernelTinShiftForward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const void *input, const void *shifts, void *output, const int batch_size,
const int time_size, const int channel_size, const int hw_size,
const int group_size, const int group_channel,
const cnrtDataType_t data_dtype, const int channel_per_core,
const int max_number_hw_per_core, const int max_length_per_core) {
if (channel_per_core >= 1) {
MLUUnion1KernelTinShift<<<k_dim, k_type, queue>>>(
input, shifts, output, batch_size, time_size, channel_size, hw_size,
group_size, group_channel, data_dtype);
} else {
MLUUnion1KernelTinShiftSplitSequence<<<k_dim, k_type, queue>>>(
input, shifts, output, batch_size, time_size, channel_size, hw_size,
group_size, group_channel, max_number_hw_per_core, max_length_per_core,
data_dtype);
}
}
void KernelTinShiftBackward(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
const void *grad_output, const void *shifts, void *grad_input,
const int batch_size, const int time_size, const int channel_size,
const int hw_size, const int group_size, const int group_channel,
const cnrtDataType_t data_dtype, const int channel_per_core,
const int max_number_hw_per_core, const int max_length_per_core) {
if (channel_per_core >= 1) {
MLUUnion1KernelTinShift<<<k_dim, k_type, queue>>>(
grad_output, shifts, grad_input, batch_size, time_size, channel_size,
hw_size, group_size, group_channel, data_dtype);
} else {
MLUUnion1KernelTinShiftSplitSequence<<<k_dim, k_type, queue>>>(
grad_output, shifts, grad_input, batch_size, time_size, channel_size,
hw_size, group_size, group_channel, max_number_hw_per_core,
max_length_per_core, data_dtype);
}
}
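For orientation, the shift semantics implemented by the deleted kernel (as read from the memcpy offsets above) can be stated in a few lines of host-side C++. This is an illustrative CPU reference only, not part of the commit: each batch/group pair carries one shift along the time axis, and time steps that fall out of range are zero-filled.

#include <vector>

// CPU reference of the tin_shift behaviour implemented above (illustration only).
// input/output layout: [batch, time, channel, hw]; shifts layout: [batch, group].
std::vector<float> tin_shift_reference(const std::vector<float> &input,
                                       const std::vector<int> &shifts, int batch,
                                       int time, int channel, int hw, int group) {
  std::vector<float> output(input.size(), 0.f);  // out-of-range steps stay zero
  const int group_channel = channel / group;
  for (int n = 0; n < batch; ++n) {
    for (int t = 0; t < time; ++t) {
      for (int c = 0; c < channel; ++c) {
        const int shift = shifts[n * group + c / group_channel];
        const int src_t = t - shift;  // read from the time step shifted by `shift`
        if (src_t < 0 || src_t >= time) continue;
        for (int i = 0; i < hw; ++i) {
          output[((n * time + t) * channel + c) * hw + i] =
              input[((n * time + src_t) * channel + c) * hw + i];
        }
      }
    }
  }
  return output;
}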
mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp
...
@@ -14,9 +14,9 @@
 #include "mlu_common_helper.h"

 void sigmoid_focal_loss_forward_mlu(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, const float gamma,
                                     const float alpha) {
   // params check
   TORCH_CHECK(gamma >= 0, "gamma should be greater than or equal to 0. ",
               "But now gamma is ", gamma, ".");
...
@@ -82,15 +82,15 @@ void sigmoid_focal_loss_forward_mlu(Tensor input, Tensor target,
   auto handle = mluOpGetCurrentHandle();
   // launch kernel
   TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidForward(
       handle, prefer, reduction, input_desc.desc(), input_ptr,
       target_desc.desc(), target_ptr, weight_desc.desc(), weight_ptr, alpha,
       gamma, output_desc.desc(), output_ptr));
 }

 void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target, Tensor weight,
                                      Tensor output, const float gamma,
                                      const float alpha) {
   // params check
   TORCH_CHECK(gamma >= 0, "gamma should be greater than or equal to 0. ",
               "But now gamma is ", gamma, ".");
...
@@ -158,10 +158,10 @@ void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target,
   auto handle = mluOpGetCurrentHandle();
   // launch kernel
   TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidBackward(
       handle, prefer, reduction, input_desc.desc(), input_ptr,
       target_desc.desc(), target_ptr, weight_desc.desc(), weight_ptr, alpha,
       gamma, output_desc.desc(), output_ptr));
 }

 void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
...

(The hunks in this file appear to only re-wrap long argument lists; the calls themselves are unchanged, which is consistent with the +14 −14 count.)
mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h
...
@@ -18,8 +18,8 @@
 #include "pytorch_device_registry.hpp"

 #define MLUOP_MAJOR 0
-#define MLUOP_MINOR 7
-#define MLUOP_PATCHLEVEL 1
+#define MLUOP_MINOR 8
+#define MLUOP_PATCHLEVEL 0

 /*************************************************************************
  * This MACRO contains operations of simple tensor to mlu-tensor.
...
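The only change here raises the required mlu-ops version from 0.7.1 to 0.8.0. As a hedged illustration of how such a requirement could be enforced, the sketch below compares a queried version triple against the macros; the comparison is an assumption for illustration, not mmcv's actual check, and how the installed version is queried is left abstract.

#include <stdexcept>
#include <string>

// Values mirrored from the header hunk above (required minimum mlu-ops version).
#define MLUOP_MAJOR 0
#define MLUOP_MINOR 8
#define MLUOP_PATCHLEVEL 0

// Sketch: reject library versions older than the required minimum.
void check_mluop_version(int lib_major, int lib_minor, int lib_patch) {
  const long required =
      MLUOP_MAJOR * 10000L + MLUOP_MINOR * 100L + MLUOP_PATCHLEVEL;
  const long actual = lib_major * 10000L + lib_minor * 100L + lib_patch;
  if (actual < required) {
    throw std::runtime_error(
        "mlu-ops >= " + std::to_string(MLUOP_MAJOR) + "." +
        std::to_string(MLUOP_MINOR) + "." + std::to_string(MLUOP_PATCHLEVEL) +
        " is required, but " + std::to_string(lib_major) + "." +
        std::to_string(lib_minor) + "." + std::to_string(lib_patch) +
        " was found");
  }
}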
mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp
...
@@ -74,8 +74,8 @@ void RoIPointPool3dForwardMLUKernelLauncher(
               pts_feature.numel(), ".");

   // set contiguous
   auto xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
       xyz, xyz.suggest_memory_format());
   auto pts_feature_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
       pts_feature, pts_feature.suggest_memory_format());
   auto boxes3d_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
...
@@ -92,13 +92,16 @@ void RoIPointPool3dForwardMLUKernelLauncher(
   auto pts_feature_ptr = pts_feature_impl->cnnlMalloc();
   auto boxes3d_impl = torch_mlu::getMluTensorImpl(boxes3d_contiguous);
   auto boxes3d_ptr = boxes3d_impl->cnnlMalloc();
   auto pooled_features_impl =
       torch_mlu::getMluTensorImpl(pooled_features_contiguous);
   auto pooled_features_ptr = pooled_features_impl->cnnlMalloc();
   auto pooled_empty_flag_impl =
       torch_mlu::getMluTensorImpl(pooled_empty_flag_contiguous);
   auto pooled_empty_flag_ptr = pooled_empty_flag_impl->cnnlMalloc();

   // create tensor descriptors
   MluOpTensorDescriptor xyz_desc, pts_feature_desc, boxes3d_desc,
       pooled_features_desc, pooled_empty_flag_desc;
   xyz_desc.set(xyz_contiguous);
   pts_feature_desc.set(pts_feature_contiguous);
   boxes3d_desc.set(boxes3d_contiguous);
...
@@ -108,10 +111,11 @@ void RoIPointPool3dForwardMLUKernelLauncher(
   // get workspace
   size_t workspace_size = 0;
   auto handle = mluOpGetCurrentHandle();
   TORCH_MLUOP_CHECK(mluOpGetRoiPointPool3dWorkspaceSize(
       handle, batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
       xyz_desc.desc(), pts_feature_desc.desc(), boxes3d_desc.desc(),
       pooled_features_desc.desc(), pooled_empty_flag_desc.desc(),
       &workspace_size));
   auto workspace = at::empty(workspace_size, xyz.options().dtype(at::kByte));
   auto workspace_impl = torch_mlu::getMluTensorImpl(workspace);
...
@@ -120,8 +124,8 @@ void RoIPointPool3dForwardMLUKernelLauncher(
       handle, batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num,
       xyz_desc.desc(), xyz_ptr, pts_feature_desc.desc(), pts_feature_ptr,
       boxes3d_desc.desc(), boxes3d_ptr, workspace_ptr, workspace_size,
       pooled_features_desc.desc(), pooled_features_ptr,
       pooled_empty_flag_desc.desc(), (int *)pooled_empty_flag_ptr));
 }

 void roipoint_pool3d_forward_mlu(int batch_size, int pts_num, int boxes_num,
...

(As with focal_loss_sigmoid_mlu.cpp, these hunks appear to only re-wrap long statements across lines; no calls or arguments change.)
mmcv/ops/csrc/pytorch/mlu/tin_shift_mlu.cpp
...
@@ -9,65 +9,7 @@
  * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
  * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  *************************************************************************/
-#include "pytorch_device_registry.hpp"
-#include "pytorch_mlu_helper.hpp"
-
-void KernelTinShiftForward(
-    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
-    const void *input, const void *shifts, void *output, const int batch_size,
-    const int time_size, const int channel_size, const int hw_size,
-    const int group_size, const int group_channel,
-    const cnrtDataType_t data_dtype, const int channel_per_core,
-    const int max_number_hw_per_core, const int max_length_per_core);
-
-void KernelTinShiftBackward(
-    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
-    const void *grad_output, const void *shifts, void *grad_input,
-    const int batch_size, const int time_size, const int channel_size,
-    const int hw_size, const int group_size, const int group_channel,
-    const cnrtDataType_t data_dtype, const int channel_per_core,
-    const int max_number_hw_per_core, const int max_length_per_core);
-
-// policy function
-static void policyFunc(const Tensor &input, cnrtDim3_t *k_dim,
-                       cnrtFunctionType_t *k_type, int *channel_per_core,
-                       int *max_number_hw_per_core, int *max_length_per_core) {
-  const int32_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
-  const int32_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
-  auto nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore);
-  const int core_num = core_limit * cluster_limit;
-  const int batch_size = input.size(0);
-  const int time_size = input.size(1);
-  const int channel_size = input.size(2);
-  const int hw_size = input.size(3);
-  const size_t size_per_channel = time_size * hw_size * input.itemsize();
-  *channel_per_core = nram_size / size_per_channel;
-  int task_dim = 0;
-  if (*channel_per_core == 0) {
-    const size_t size_per_hw = hw_size * input.itemsize();
-    *max_number_hw_per_core = nram_size / size_per_hw;
-    if (*max_number_hw_per_core <= 0) {
-      *max_length_per_core = nram_size / input.itemsize();
-    }
-    int tmp_max_number_hw_per_core =
-        *max_number_hw_per_core > 0 ? *max_number_hw_per_core : 1;
-    const int loop_time =
-        (time_size / (tmp_max_number_hw_per_core)) +
-        ((time_size % (tmp_max_number_hw_per_core)) > 0 ? 1 : 0);
-    task_dim = batch_size * channel_size * loop_time < core_num
-                   ? batch_size * channel_size * loop_time
-                   : core_num;
-  } else {
-    task_dim = batch_size * channel_size < core_num ? batch_size * channel_size
-                                                    : core_num;
-  }
-  k_dim->x = core_limit;
-  k_dim->y = (task_dim / core_limit) > 0 ? (task_dim / core_limit) : 1;
-  k_dim->z = 1;
-  *k_type = CNRT_FUNC_TYPE_UNION1;
-}
+#include "mlu_common_helper.h"

 void TINShiftForwardMLUKernelLauncher(Tensor input, Tensor shift,
                                       Tensor output) {
...
@@ -89,40 +31,37 @@ void TINShiftForwardMLUKernelLauncher(Tensor input, Tensor shift,
   if (input.size(1) == 0) {
     return;
   }
-  cnrtDim3_t k_dim;
-  cnrtFunctionType_t k_type;
-  int channel_per_core = 0;
-  int max_number_hw_per_core = 0;
-  int max_length_per_core = 0;
-  policyFunc(input, &k_dim, &k_type, &channel_per_core,
-             &max_number_hw_per_core, &max_length_per_core);
-  const int batch_size = input.size(0);
-  const int time_size = input.size(1);
-  const int channel_size = input.size(2);
-  const int hw_size = input.size(3);
-  const int group_size = shift.size(1);
-  int group_channel = channel_size / group_size;
-  // get tensor impl
-  auto input_impl = torch_mlu::getMluTensorImpl(input);
-  auto shift_impl = torch_mlu::getMluTensorImpl(shift);
-  auto output_impl = torch_mlu::getMluTensorImpl(output);
+  // set contiguous
+  auto input_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      input, input.suggest_memory_format());
+  auto shift_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      shift, shift.suggest_memory_format());
+  auto output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      output, output.suggest_memory_format());
-  // get compute queue
-  auto queue = torch_mlu::getCurQueue();
+  // get tensor impl
+  auto input_impl = torch_mlu::getMluTensorImpl(input_contiguous);
+  auto shift_impl = torch_mlu::getMluTensorImpl(shift_contiguous);
+  auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);
   // get the mlu ptr
   auto input_ptr = input_impl->cnnlMalloc();
   auto shift_ptr = shift_impl->cnnlMalloc();
   auto output_ptr = output_impl->cnnlMalloc();
-  cnrtDataType_t data_dtype = torch_mlu::toCnrtDtype(input.dtype());
+  // set tensor descriptor
+  MluOpTensorDescriptor input_desc, shift_desc, output_desc;
+  input_desc.set(input_contiguous);
+  shift_desc.set(shift_contiguous);
+  output_desc.set(output_contiguous);
+  // get current handle
+  auto handle = mluOpGetCurrentHandle();
-  KernelTinShiftForward(k_dim, k_type, queue, input_ptr, shift_ptr, output_ptr,
-                        batch_size, time_size, channel_size, hw_size,
-                        group_size, group_channel, data_dtype, channel_per_core,
-                        max_number_hw_per_core, max_length_per_core);
+  TORCH_MLUOP_CHECK(mluOpTinShiftForward(handle, input_desc.desc(), input_ptr,
+                                         shift_desc.desc(), shift_ptr,
+                                         output_desc.desc(), output_ptr));
 }

 void TINShiftBackwardMLUKernelLauncher(Tensor grad_output, Tensor shift,
...
@@ -148,41 +87,37 @@ void TINShiftBackwardMLUKernelLauncher(Tensor grad_output, Tensor shift,
   if (grad_output.size(1) == 0) {
     return;
   }
-  cnrtDim3_t k_dim;
-  cnrtFunctionType_t k_type;
-  int channel_per_core = 0;
-  int max_number_hw_per_core = 0;
-  int max_length_per_core = 0;
-  policyFunc(grad_output, &k_dim, &k_type, &channel_per_core,
-             &max_number_hw_per_core, &max_length_per_core);
-  const int batch_size = grad_output.size(0);
-  const int time_size = grad_output.size(1);
-  const int channel_size = grad_output.size(2);
-  const int hw_size = grad_output.size(3);
-  const int group_size = shift.size(1);
-  int group_channel = channel_size / group_size;
-  // get tensor impl
-  auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output);
-  auto shift_impl = torch_mlu::getMluTensorImpl(shift);
-  auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input);
+  // set contiguous
+  auto grad_output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      grad_output, grad_output.suggest_memory_format());
+  auto shift_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      shift, shift.suggest_memory_format());
+  auto grad_input_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
+      grad_input, grad_input.suggest_memory_format());
-  // get compute queue
-  auto queue = torch_mlu::getCurQueue();
+  // get tensor impl
+  auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output_contiguous);
+  auto shift_impl = torch_mlu::getMluTensorImpl(shift_contiguous);
+  auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_contiguous);
   // get the mlu ptr
   auto grad_output_ptr = grad_output_impl->cnnlMalloc();
   auto shift_ptr = shift_impl->cnnlMalloc();
   auto grad_input_ptr = grad_input_impl->cnnlMalloc();
-  cnrtDataType_t data_dtype = torch_mlu::toCnrtDtype(grad_output.dtype());
+  // set tensor descriptor
+  MluOpTensorDescriptor grad_output_desc, shift_desc, grad_input_desc;
+  grad_output_desc.set(grad_output_contiguous);
+  shift_desc.set(shift_contiguous);
+  grad_input_desc.set(grad_input_contiguous);
+  // get current handle
+  auto handle = mluOpGetCurrentHandle();
-  KernelTinShiftBackward(k_dim, k_type, queue, grad_output_ptr, shift_ptr,
-                         grad_input_ptr, batch_size, time_size, channel_size,
-                         hw_size, group_size, group_channel, data_dtype,
-                         channel_per_core, max_number_hw_per_core,
-                         max_length_per_core);
+  TORCH_MLUOP_CHECK(mluOpTinShiftBackward(
+      handle, grad_output_desc.desc(), grad_output_ptr, shift_desc.desc(),
+      shift_ptr, grad_input_desc.desc(), grad_input_ptr));
 }

 void tin_shift_forward_mlu(Tensor input, Tensor shift, Tensor output) {
...
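The tail of the file is truncated here. In mmcv, MLU launchers like tin_shift_forward_mlu are normally registered with the generic dispatcher declared in pytorch_device_registry.hpp; the lines below sketch that usual pattern and are assumed rather than quoted from this diff.

// Assumed registration sketch (not visible in the truncated hunk above):
// bind the MLU launchers to the device-dispatched *_impl entry points.
void tin_shift_forward_impl(Tensor input, Tensor shift, Tensor output);
void tin_shift_backward_impl(Tensor grad_output, Tensor shift, Tensor grad_input);

REGISTER_DEVICE_IMPL(tin_shift_forward_impl, MLU, tin_shift_forward_mlu);
REGISTER_DEVICE_IMPL(tin_shift_backward_impl, MLU, tin_shift_backward_mlu);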