composable_kernel — Commit 5db79de0
Authored Jul 14, 2022 by carlushuang
Parent: 5024f317

    add a direct bias-relu-add implementation

Showing 5 changed files with 2135 additions and 1 deletion (+2135, -1)
Changed files:
  example/cpu_02_conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp (+8, -1)
  include/ck/tensor_operation/cpu/device/device_convnd_direct_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp (+1052)
  include/ck/tensor_operation/cpu/grid/gridwise_direct_conv_bias_activation_add_avx2.hpp (+1005)
  include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp (+6)
  library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_direct_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp (+64)
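The new instances register under the AddReluAdd output element-wise operation, i.e. the raw convolution result gets a bias, a ReLU, and then a residual add. As a point of reference only, the fused epilogue can be written out directly; the sketch below illustrates the assumed semantics (the exact operand order is defined by the library's AddReluAdd functor, not here), and every name in it is local to this example.

// Illustrative reference of a fused bias + relu + add epilogue over the implicit-GEMM view
// of the NHWK output (assumed semantics of "AddReluAdd"; consult the element-wise functor
// in the library for the authoritative definition).
#include <algorithm>
#include <cstddef>

void reference_bias_relu_add_epilogue(float* out,            // conv result, gemm_m x gemm_n
                                      const float* bias,     // assumed broadcast along gemm_n (= K)
                                      const float* residual, // tensor added after the activation
                                      std::size_t gemm_m,    // N * Ho * Wo
                                      std::size_t gemm_n)    // K
{
    for(std::size_t m = 0; m < gemm_m; ++m)
        for(std::size_t n = 0; n < gemm_n; ++n)
        {
            const std::size_t idx = m * gemm_n + n;
            out[idx] = std::max(out[idx] + bias[n], 0.0f) + residual[idx];
        }
}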
example/cpu_02_conv2d_fwd_bias_relu_add/cpu_conv2d_fwd_bias_relu_add.cpp  (view at 5db79de0)

@@ -20,7 +20,7 @@
 #define TEST_FUSION_BIAS_RELU 1
 #define TEST_FUSION_BIAS 2
 #define TEST_FUSION_BIAS_ADD_RELU 3
-#define TEST_FUSION TEST_FUSION_BIAS_ADD_RELU
+#define TEST_FUSION TEST_FUSION_BIAS_RELU_ADD
 #define TEST_LAYOUT_NHWC_KYXC_NHWK 0
 #define TEST_LAYOUT_NHWC_KYXCK8_NHWK 1

@@ -171,6 +171,11 @@ void add_device_conv2d_fwd_bias_add_relu_avx2_nhwc_yxck_nhwk_mt(
     std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddAddRelu>>&
         instances);

+// ------------------ direct-conv nhwc-kcyxk8-nhwk
+void add_device_conv2d_direct_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(
+    std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>&
+        instances);
+
 } // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
 } // namespace device
 } // namespace cpu

@@ -623,6 +628,8 @@ int main(int argc, char* argv[])
         add_device_conv2d_fwd_bias_relu_add_avx2_nhwc_kyxck8_nhwk_local_c(conv_ptrs);
     }
+    ck::tensor_operation::cpu::device::device_conv2d_fwd_bias_activation_add_avx2_instance::
+        add_device_conv2d_direct_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(conv_ptrs);
 #endif
 #if TEST_LAYOUT == TEST_LAYOUT_NHWC_YXCK_NHWK
     if(omp_get_max_threads() > 1)
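Once the new direct-conv instances are appended to conv_ptrs, the example exercises them through the usual device-op interface (MakeArgumentPointer, IsSupportedArgument, MakeInvokerPointer, Run), all of which the new device class below implements. The following is a condensed sketch of that selection/run pattern only; every pointer, extent and element-wise op (p_in, p_wei, p_out, p_bias, p_add, N, K, C, the length/stride/pad vectors, in_element_op, wei_element_op, out_element_op) is assumed to be prepared as elsewhere in the example.

// Sketch of the instance loop driving the registered conv_ptrs (setup elided, see note above).
for(auto& conv_ptr : conv_ptrs)
{
    auto argument = conv_ptr->MakeArgumentPointer(p_in, p_wei, p_out, p_bias, p_add,
                                                  N, K, C,
                                                  input_spatial_lengths,
                                                  filter_spatial_lengths,
                                                  output_spatial_lengths,
                                                  conv_filter_strides,
                                                  conv_filter_dilations,
                                                  input_left_pads,
                                                  input_right_pads,
                                                  in_element_op, wei_element_op, out_element_op);

    if(!conv_ptr->IsSupportedArgument(argument.get()))
        continue; // this instance cannot handle the requested problem; try the next one

    auto invoker = conv_ptr->MakeInvokerPointer();
    invoker->Run(argument.get()); // pass nrepeat > 1 to time the kernel
}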
include/ck/tensor_operation/cpu/device/device_convnd_direct_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp  (new file, 0 → 100644, view at 5db79de0)
#ifndef DEVICE_CONV2D_DIRECT_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_KYXCK8_NHWK_HPP
#define DEVICE_CONV2D_DIRECT_FWD_BIAS_ACTIVATION_ADD_AVX2_NHWC_KYXCK8_NHWK_HPP

#include <iostream>
#include <sstream>
#include <numeric>
#include <memory>
#include <vector>

#include "ck/device_utility/kernel_launch.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/tensor_operation/cpu/device/device_base_cpu.hpp"
#include "ck/tensor_operation/cpu/device/device_conv_fwd_cpu.hpp"
#include "ck/tensor_operation/cpu/device/convolution_forward_specialization_cpu.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/cpu/grid/gridwise_direct_conv_bias_activation_add_avx2.hpp"
#include "ck/tensor_operation/cpu/thread/threadwise_gemm_avx2.hpp"
#include "ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp"

namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
template <typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename BiasDataType,
          typename AddDataType,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation,
          ConvolutionForwardSpecialization_t ConvForwardSpecialization,
          ck::index_t NumDimSpatial,
          ck::index_t MPerThread,
          ck::index_t NPerThread,
          bool UseALocalBuffer,
          bool UseBLocalBuffer,
          bool UseCLocalBuffer,
          bool FuseBias,
          bool FuseAdd,
          bool BiasAlongGemmM>
struct DeviceConvNDDirectFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
    : public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
                                            WeiElementwiseOperation,
                                            OutElementwiseOperation>
{
    using DeviceOp =
        DeviceConvNDDirectFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K;

    using ADataType  = InDataType;
    using BDataType  = WeiDataType;
    using CDataType  = OutDataType;
    using C0DataType = BiasDataType;
    using C1DataType = AddDataType;

    using AElementwiseOperation = InElementwiseOperation;
    using BElementwiseOperation = WeiElementwiseOperation;
    using CElementwiseOperation = OutElementwiseOperation;

    // TODO make A/B datatype different
    using ABDataType = InDataType;

    static constexpr index_t NDimSpatial = NumDimSpatial;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr bool NonTemporalStore = false;

    DeviceConvNDDirectFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K(
        const DeviceConvFwdDynamicTunable& dtune)
        : gridwise_gemm(dtune)
    {
    }
    static constexpr auto GetThreadwiseGemm_Dispatch()
    {
        if constexpr(MPerThread == 4 && NPerThread == 24)
        {
            return ck::cpu::ThreadwiseGemmAvx2_MxN_4x24_Dispatch<InDataType,
                                                                 WeiDataType,
                                                                 OutDataType,
                                                                 ck::tensor_layout::gemm::RowMajor,
                                                                 ck::tensor_layout::gemm::ColumnMajor,
                                                                 NonTemporalStore>{};
        }
        else if constexpr(MPerThread == 6 && NPerThread == 16)
        {
            return ck::cpu::ThreadwiseGemmAvx2_MxN_6x16_Dispatch<InDataType,
                                                                 WeiDataType,
                                                                 OutDataType,
                                                                 ck::tensor_layout::gemm::RowMajor,
                                                                 ck::tensor_layout::gemm::ColumnMajor,
                                                                 NonTemporalStore>{};
        }
        else
        {
            // static_assert(false, "invalid Mr/Nr");
        }
    }

    using ThreadwiseGemm_Dispatch = decltype(GetThreadwiseGemm_Dispatch());
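    // Note (illustration, not part of the commit): the two supported Mr x Nr micro-tiles map
    // onto the AVX2 register file the same way, assuming 8 fp32 lanes per 256-bit YMM register:
    //   4 x 24 tile: 4 * (24 / 8) = 12 accumulator registers
    //   6 x 16 tile: 6 * (16 / 8) = 12 accumulator registers
    // leaving the remaining YMM registers free for A broadcasts and B vector loads.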
    static auto GetWeightTensorDescriptor(ck::index_t gemm_k, ck::index_t gemm_n)
    {
        return make_naive_tensor_descriptor_packed(make_tuple(gemm_n / 8, gemm_k, 8));
    }

    static auto GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
    {
        const auto out_gemm_m_n_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));
        return out_gemm_m_n_grid_desc;
    }

    static auto MakeBiasTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n)
    {
        if constexpr(BiasAlongGemmM)
        {
            return make_naive_tensor_descriptor_packed(make_tuple(gemm_m));
        }
        else
        {
            return make_naive_tensor_descriptor_packed(make_tuple(gemm_n));
        }
    }
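    // Note (illustration, not part of the commit): the (gemm_n / 8, gemm_k, 8) weight descriptor
    // above encodes the KYXCK8 layout with gemm_n = K and gemm_k = Y * X * C. Assuming K is a
    // multiple of 8, the flat element offset of weight (k, y, x, c) is
    //   offset = ((k / 8) * (Y * X * C) + (y * X + x) * C + c) * 8 + (k % 8)
    // so each group of 8 consecutive output channels is contiguous in memory and matches one
    // 8-float AVX2 vector load of the micro-kernel's B operand.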
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
GetInputTensorDescriptor
(
ck
::
index_t
N
,
ck
::
index_t
C
,
ck
::
index_t
gemm_m
,
ck
::
index_t
gemm_k
,
const
std
::
vector
<
ck
::
index_t
>&
input_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
filter_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_strides
,
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_dilations
,
const
std
::
vector
<
ck
::
index_t
>&
input_left_pads
,
const
std
::
vector
<
ck
::
index_t
>&
input_right_pads
)
{
const
index_t
Wi
=
input_spatial_lengths
[
0
];
const
index_t
Wo
=
output_spatial_lengths
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
const
auto
in_gemm_m_k_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_m
,
gemm_k
));
return
in_gemm_m_k_grid_desc
;
}
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
)
{
const
auto
in_n_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
const
auto
in_n_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
in_gemm_m_k_grid_desc
=
transform_tensor_descriptor
(
in_n_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Wo
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_gemm_m_k_grid_desc
;
}
else
{
const
index_t
X
=
filter_spatial_lengths
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
0
];
const
auto
in_n_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Wi
,
C
));
const
auto
in_n_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
in_n_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_gemm_m_k_grid_desc
=
transform_tensor_descriptor
(
in_n_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Wo
)),
make_merge_transform
(
make_tuple
(
X
,
C
))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_gemm_m_k_grid_desc
;
}
}
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
GetInputTensorDescriptor
(
ck
::
index_t
N
,
ck
::
index_t
C
,
ck
::
index_t
gemm_m
,
ck
::
index_t
gemm_k
,
const
std
::
vector
<
ck
::
index_t
>&
input_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
filter_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_strides
,
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_dilations
,
const
std
::
vector
<
ck
::
index_t
>&
input_left_pads
,
const
std
::
vector
<
ck
::
index_t
>&
input_right_pads
)
{
const
index_t
Hi
=
input_spatial_lengths
[
0
];
const
index_t
Wi
=
input_spatial_lengths
[
1
];
const
index_t
Ho
=
output_spatial_lengths
[
0
];
const
index_t
Wo
=
output_spatial_lengths
[
1
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
const
auto
in_gemm_m_k_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_m
,
gemm_k
));
return
in_gemm_m_k_grid_desc
;
}
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
)
{
const
auto
in_n_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
const
auto
in_n_ho_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Ho
),
make_tuple
(
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_gemm_m_k_grid_desc
=
transform_tensor_descriptor
(
in_n_ho_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_gemm_m_k_grid_desc
;
}
else
{
const
index_t
Y
=
filter_spatial_lengths
[
0
];
const
index_t
X
=
filter_spatial_lengths
[
1
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
const
index_t
InLeftPadH
=
input_left_pads
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
1
];
const
index_t
InRightPadH
=
input_right_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
1
];
const
auto
in_n_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemm_m_k_grid_desc
=
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)),
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
))),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_gemm_m_k_grid_desc
;
}
}
template
<
ck
::
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
GetInputTensorDescriptor
(
ck
::
index_t
N
,
ck
::
index_t
C
,
ck
::
index_t
gemm_m
,
ck
::
index_t
gemm_k
,
ck
::
index_t
gemm_m_pad
,
const
std
::
vector
<
ck
::
index_t
>&
input_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
filter_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
output_spatial_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_strides
,
const
std
::
vector
<
ck
::
index_t
>&
conv_filter_dilations
,
const
std
::
vector
<
ck
::
index_t
>&
input_left_pads
,
const
std
::
vector
<
ck
::
index_t
>&
input_right_pads
)
{
const
index_t
Di
=
input_spatial_lengths
[
0
];
const
index_t
Hi
=
input_spatial_lengths
[
1
];
const
index_t
Wi
=
input_spatial_lengths
[
2
];
const
index_t
Do
=
output_spatial_lengths
[
0
];
const
index_t
Ho
=
output_spatial_lengths
[
1
];
const
index_t
Wo
=
output_spatial_lengths
[
2
];
const
index_t
ConvStrideD
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
1
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
2
];
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
const
auto
in_gemm_m_k_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
gemm_m
,
gemm_k
));
return
in_gemm_m_k_grid_desc
;
}
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
)
{
const
auto
in_n_di_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
const
auto
in_n_do_ho_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_di_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Do
),
make_tuple
(
ConvStrideD
)),
make_embed_transform
(
make_tuple
(
Ho
),
make_tuple
(
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
in_gemm_m_k_grid_desc
=
transform_tensor_descriptor
(
in_n_do_ho_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_gemm_m_k_grid_desc
;
}
else
{
const
index_t
Z
=
filter_spatial_lengths
[
0
];
const
index_t
Y
=
filter_spatial_lengths
[
1
];
const
index_t
X
=
filter_spatial_lengths
[
2
];
const
index_t
ConvDilationD
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
1
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
2
];
const
index_t
InLeftPadD
=
input_left_pads
[
0
];
const
index_t
InLeftPadH
=
input_left_pads
[
1
];
const
index_t
InLeftPadW
=
input_left_pads
[
2
];
const
index_t
InRightPadD
=
input_right_pads
[
0
];
const
index_t
InRightPadH
=
input_right_pads
[
1
];
const
index_t
InRightPadW
=
input_right_pads
[
2
];
const
auto
in_n_di_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
));
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_n_di_hi_wi_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Di
,
InLeftPadD
,
InRightPadD
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
in_n_z_do_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Z
,
Do
),
make_tuple
(
ConvDilationD
,
ConvStrideD
)),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{}));
const
auto
in_gemm_m_k_grid_desc
=
transform_tensor_descriptor
(
in_n_z_do_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
)),
make_merge_transform
(
make_tuple
(
Z
,
Y
,
X
,
C
))),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_gemm_m_k_grid_desc
;
}
}
    static index_t GetGemmM(ck::index_t N, const std::vector<ck::index_t>& output_spatial_lengths)
    {
        return N * std::accumulate(std::begin(output_spatial_lengths),
                                   std::end(output_spatial_lengths),
                                   1,
                                   std::multiplies<ck::index_t>());
    }

    static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths)
    {
        return C * std::accumulate(std::begin(filter_spatial_lengths),
                                   std::end(filter_spatial_lengths),
                                   1,
                                   std::multiplies<ck::index_t>());
    }

    static index_t GetGemmN(ck::index_t K)
    {
        // return ck::math::integer_least_multiple(K,
        // ThreadwiseGemm_Dispatch::MatrixBMinVectorSize);
        return K;
    }
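    // Worked example (illustrative sizes only) of the implicit-GEMM extents these helpers return:
    //   N = 2, Ho = Wo = 28, C = 64, Y = X = 3, K = 128
    //   GemmM = N * Ho * Wo = 2 * 28 * 28 = 1568   (one GEMM row per output pixel)
    //   GemmK = C * Y * X   = 64 * 3 * 3  = 576    (one reduction element per channel/filter tap)
    //   GemmN = K           = 128                  (one GEMM column per output channel)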
    static auto MakeABCGridDescriptor(ck::index_t N,
                                      ck::index_t K,
                                      ck::index_t C,
                                      std::vector<ck::index_t> input_spatial_lengths,
                                      std::vector<ck::index_t> filter_spatial_lengths,
                                      std::vector<ck::index_t> output_spatial_lengths,
                                      std::vector<ck::index_t> conv_filter_strides,
                                      std::vector<ck::index_t> conv_filter_dilations,
                                      std::vector<ck::index_t> input_left_pads,
                                      std::vector<ck::index_t> input_right_pads)
    {
        using namespace ck;

        const index_t GemmM = GetGemmM(N, output_spatial_lengths);
        const index_t GemmN = GetGemmN(K);
        const index_t GemmK = GetGemmK(C, filter_spatial_lengths);

        // A:
        const auto in_gemm_m_k_grid_desc =
            GetInputTensorDescriptor<NumDimSpatial>(N,
                                                    C,
                                                    GemmM,
                                                    GemmK,
                                                    input_spatial_lengths,
                                                    filter_spatial_lengths,
                                                    output_spatial_lengths,
                                                    conv_filter_strides,
                                                    conv_filter_dilations,
                                                    input_left_pads,
                                                    input_right_pads);
        // B:
        const auto wei_gemm_n0_k_n1_grid_desc = GetWeightTensorDescriptor(GemmK, GemmN);
        // C:
        const auto out_gemm_m_n_grid_desc = GetOutputTensorDescriptor(GemmM, GemmN);

        return make_tuple(
            in_gemm_m_k_grid_desc, wei_gemm_n0_k_n1_grid_desc, out_gemm_m_n_grid_desc);
    }
    template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
    static auto GetABCGridDesc()
    {
        return MakeABCGridDescriptor(1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
    }

    template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
    static auto GetABCGridDesc()
    {
        return MakeABCGridDescriptor(
            1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
    }

    template <ck::index_t NDim, typename std::enable_if<NDim == 3, bool>::type = false>
    static auto GetABCGridDesc()
    {
        return MakeABCGridDescriptor(
            1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
    }
    using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());

    using AGridDesc  = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
    using BGridDesc  = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
    using CGridDesc  = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
    using C0GridDesc = remove_cvref_t<decltype(MakeBiasTensorDescriptor(1, 1))>;
    using C1GridDesc = CGridDesc;
    static constexpr auto GetInputBlockDescriptor()
    {
        if constexpr(UseALocalBuffer)
        {
            // return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, KPerBlock));
            return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
        }
        else
        {
            return AGridDesc{};
        }
    }

    static constexpr auto GetWeightBlockDescriptor()
    {
        if constexpr(UseBLocalBuffer)
        {
            // return make_naive_tensor_descriptor_packed(make_tuple(
            //     math::integer_divide_ceil(NPerBlock,
            //     ThreadwiseGemm_Dispatch::MatrixBMinVectorSize), KPerBlock,
            //     ThreadwiseGemm_Dispatch::MatrixBMinVectorSize));
            return make_naive_tensor_descriptor_packed(make_tuple(0, 0, 0));
        }
        else
        {
            return BGridDesc{};
        }
    }

    static constexpr auto GetOutputBlockDescriptor()
    {
        if constexpr(UseCLocalBuffer)
        {
            // return make_naive_tensor_descriptor_packed(make_tuple(MPerBlock, NPerBlock));
            return make_naive_tensor_descriptor_packed(make_tuple(0, 0));
        }
        else
        {
            return CGridDesc{};
        }
    }
    // static constexpr bool UseCLocalBuffer = false;

    using AThreadwiseCopy =
        ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_In_NHWC<
            ADataType,
            ADataType,
            AGridDesc,
            decltype(GetInputBlockDescriptor()),
            InElementwiseOperation,
            !UseALocalBuffer,
            ConvForwardSpecialization>;

    using BThreadwiseCopy =
        ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_ConvFwd_Wei_KYXCK8<
            BDataType,
            BDataType,
            BGridDesc,
            decltype(GetWeightBlockDescriptor()),
            WeiElementwiseOperation,
            !UseBLocalBuffer,
            ConvForwardSpecialization>;
    static constexpr auto GetCThreadwiseCopy()
    {
        constexpr ck::index_t C_nDim = CGridDesc::GetNumOfDimension();

        if constexpr(FuseBias && FuseAdd)
        {
            return ck::cpu::
                ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN<
                    CDataType,
                    C0DataType,
                    C1DataType,
                    CDataType,
                    CGridDesc,
                    C0GridDesc,
                    C1GridDesc,
                    decltype(GetOutputBlockDescriptor()),
                    OutElementwiseOperation,
                    !UseCLocalBuffer,
                    BiasAlongGemmM>(CGridDesc{},
                                    ck::make_zero_multi_index<C_nDim>(),
                                    GetOutputBlockDescriptor(),
                                    ck::make_zero_multi_index<C_nDim>(),
                                    OutElementwiseOperation{});
        }
        else if constexpr(FuseBias && !FuseAdd)
        {
            return ck::cpu::ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN<
                CDataType,
                C0DataType,
                C1DataType,
                CDataType,
                CGridDesc,
                C0GridDesc,
                C1GridDesc,
                decltype(GetOutputBlockDescriptor()),
                OutElementwiseOperation,
                !UseCLocalBuffer,
                BiasAlongGemmM>(CGridDesc{},
                                ck::make_zero_multi_index<C_nDim>(),
                                GetOutputBlockDescriptor(),
                                ck::make_zero_multi_index<C_nDim>(),
                                OutElementwiseOperation{});
        }
    }

    using CThreadwiseCopy = decltype(GetCThreadwiseCopy());

    using GridwiseGemm = ck::cpu::GridwiseDirectConvNHWCBiasActivationAddAvx2<
        ADataType,               // InDataType,
        BDataType,               // WeiDataType,
        CDataType,               // OutDataType,
        C0DataType,              // C0DataType
        C1DataType,              // C1DataType
        AGridDesc,               // AGridDesc,
        BGridDesc,               // BGridDesc,
        CGridDesc,               // CGridDesc,
        C0GridDesc,              // C0GridDesc,
        C1GridDesc,              // C1GridDesc,
        AElementwiseOperation,   // AElementwiseOperation,
        BElementwiseOperation,   // BElementwiseOperation,
        CElementwiseOperation,   // CElementwiseOperation,
        ThreadwiseGemm_Dispatch, // ThreadwiseGemm_Dispatch,
        AThreadwiseCopy,         // AThreadwiseCopy
        BThreadwiseCopy,         // BThreadwiseCopy
        CThreadwiseCopy,         // CThreadwiseCopy
        ck::Sequence<0, 1>,      // ThreadMNAccessOrder
        UseALocalBuffer,         // UseALocalBuffer
        UseBLocalBuffer,         // UseBLocalBuffer
        UseCLocalBuffer          // UseCLocalBuffer
        >;

    GridwiseGemm gridwise_gemm;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
InDataType
*
p_in_grid
,
const
WeiDataType
*
p_wei_grid
,
OutDataType
*
p_out_grid
,
const
BiasDataType
*
p_bias_grid
,
const
AddDataType
*
p_add_grid
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
:
p_a_grid_
{
p_in_grid
},
p_b_grid_
{
p_wei_grid
},
p_c_grid_
{
p_out_grid
},
p_c0_grid_
{
p_bias_grid
},
p_c1_grid_
{
p_add_grid
},
a_grid_desc_
{},
b_grid_desc_
{},
c_grid_desc_
{},
c0_grid_desc_
{},
c1_grid_desc_
{},
a_element_op_
{
in_element_op
},
b_element_op_
{
wei_element_op
},
c_element_op_
{
out_element_op
},
Conv_N_
{
N
},
Conv_K_
{
K
},
Conv_C_
{
C
},
input_spatial_lengths_
{
input_spatial_lengths
},
filter_spatial_lengths_
{
filter_spatial_lengths
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
output_spatial_lengths_
{
output_spatial_lengths
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
{
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor
(
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
a_grid_desc_
=
descs
[
I0
];
b_grid_desc_
=
descs
[
I1
];
c_grid_desc_
=
descs
[
I2
];
c0_grid_desc_
=
DeviceOp
::
MakeBiasTensorDescriptor
(
GetGemmM
(
N
,
output_spatial_lengths
),
GetGemmN
(
K
));
c1_grid_desc_
=
descs
[
I2
];
}
// private:
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
const
C0DataType
*
p_c0_grid_
;
const
C1DataType
*
p_c1_grid_
;
AGridDesc
a_grid_desc_
;
BGridDesc
b_grid_desc_
;
CGridDesc
c_grid_desc_
;
C0GridDesc
c0_grid_desc_
;
C1GridDesc
c1_grid_desc_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
// for checking IsSupportedArgument()
index_t
Conv_N_
;
index_t
Conv_K_
;
index_t
Conv_C_
;
std
::
vector
<
index_t
>
input_spatial_lengths_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
vector
<
index_t
>
output_spatial_lengths_
;
std
::
vector
<
index_t
>
conv_filter_strides_
;
std
::
vector
<
index_t
>
conv_filter_dilations_
;
std
::
vector
<
index_t
>
input_left_pads_
;
std
::
vector
<
index_t
>
input_right_pads_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
GridwiseGemm
gridwise_gemm
;
Invoker
(
const
GridwiseGemm
&
gridwise_gemm_
)
:
gridwise_gemm
(
gridwise_gemm_
)
{}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{},
int
nrepeat
=
1
)
{
if
(
!
gridwise_gemm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemmAvx2_MxN has invalid setting"
);
}
// memset(arg.p_c_grid_, 0, arg.c_grid_desc_.GetElementSpaceSize());
const
auto
kernel
=
ck
::
cpu
::
kernel_direct_conv_nhwc_bias_activation_add_avx_mxn
<
GridwiseGemm
,
ADataType
,
BDataType
,
CDataType
,
C0DataType
,
C1DataType
,
AGridDesc
,
BGridDesc
,
CGridDesc
,
C0GridDesc
,
C1GridDesc
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
;
float
ave_time
=
0
;
if
(
nrepeat
!=
1
)
ave_time
=
launch_and_time_cpu_kernel
(
kernel
,
nrepeat
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c0_grid_
,
arg
.
p_c1_grid_
,
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
,
arg
.
c0_grid_desc_
,
arg
.
c1_grid_desc_
,
arg
.
Conv_N_
,
arg
.
Conv_K_
,
arg
.
Conv_C_
,
arg
.
input_spatial_lengths_
,
arg
.
filter_spatial_lengths_
,
arg
.
output_spatial_lengths_
,
arg
.
conv_filter_strides_
,
arg
.
conv_filter_dilations_
,
arg
.
input_left_pads_
,
arg
.
input_right_pads_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
// TODO: this is for benchmark purpose, so last time we clear c buffer and calculate the
// result
// memset(arg.p_c_grid_, 0xfe, arg.c_grid_desc_.GetElementSpaceSize());
launch_cpu_kernel
(
kernel
,
gridwise_gemm
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c0_grid_
,
arg
.
p_c1_grid_
,
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
,
arg
.
c0_grid_desc_
,
arg
.
c1_grid_desc_
,
arg
.
Conv_N_
,
arg
.
Conv_K_
,
arg
.
Conv_C_
,
arg
.
input_spatial_lengths_
,
arg
.
filter_spatial_lengths_
,
arg
.
output_spatial_lengths_
,
arg
.
conv_filter_strides_
,
arg
.
conv_filter_dilations_
,
arg
.
input_left_pads_
,
arg
.
input_right_pads_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
return
ave_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{},
int
nrepeat
=
1
)
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
,
nrepeat
);
}
};
    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Stride1Pad0
)
{
// check if it's 1x1, stride=1 conv
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
arg
.
conv_filter_strides_
[
0
]
==
1
&&
arg
.
conv_filter_strides_
[
1
]
==
1
&&
arg
.
input_left_pads_
[
0
]
==
0
&&
arg
.
input_left_pads_
[
1
]
==
0
&&
arg
.
input_right_pads_
[
0
]
==
0
&&
arg
.
input_right_pads_
[
1
]
==
0
))
{
return
false
;
}
}
else
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization_t
::
Filter1x1Pad0
)
{
// check if it's 1x1 conv
if
(
!
(
arg
.
filter_spatial_lengths_
[
0
]
==
1
&&
arg
.
filter_spatial_lengths_
[
1
]
==
1
&&
arg
.
input_left_pads_
[
0
]
==
0
&&
arg
.
input_left_pads_
[
1
]
==
0
&&
arg
.
input_right_pads_
[
0
]
==
0
&&
arg
.
input_right_pads_
[
1
]
==
0
))
{
return
false
;
}
}
// if(gridwise_gemm.dynamic_tunable.gemm_k_spec ==
// ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC &&
// ConvForwardSpecialization != ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
// {
// if(!(arg.Conv_C_ % gridwise_gemm.dynamic_tunable.k_per_block == 0))
// return false;
// }
if
(
!
(
arg
.
Conv_K_
%
8
==
0
))
return
false
;
// if constexpr(!UseALocalBuffer &&
// ConvForwardSpecialization !=
// ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
// {
// // TODO: We can support this in the future, as long as figure out how to express
// tensor
// // transform
// return false;
// }
// Gridwise GEMM size
return
gridwise_gemm
.
CheckValidity
(
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_
,
arg
.
c_grid_desc_
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
const
WeiDataType
*
p_wei_grid
,
OutDataType
*
p_out_grid
,
const
BiasDataType
*
p_bias_grid
,
const
AddDataType
*
p_add_grid
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
{
return
Argument
{
p_in_grid
,
p_wei_grid
,
p_out_grid
,
p_bias_grid
,
p_add_grid
,
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
in_element_op
,
wei_element_op
,
out_element_op
};
}
auto
MakeInvoker
()
{
return
Invoker
{
gridwise_gemm
};
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_grid
,
const
void
*
p_wei_grid
,
void
*
p_out_grid
,
const
void
*
p_bias_grid
,
const
void
*
p_add_grid
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
static_cast
<
const
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
OutDataType
*>
(
p_out_grid
),
static_cast
<
const
BiasDataType
*>
(
p_bias_grid
),
static_cast
<
const
AddDataType
*>
(
p_add_grid
),
N
,
K
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
in_element_op
,
wei_element_op
,
out_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{
gridwise_gemm
});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
auto
string_local_buffer
=
[](
bool
is_local_buffer
)
{
if
(
is_local_buffer
)
return
"L"
;
else
return
"G"
;
};
// clang-format off
str
<<
"DeviceConv"
<<
std
::
to_string
(
NumDimSpatial
)
<<
"DDirectFwd_BBAA_vx2_NHWC_KYXCK8"
// <<"_FS"<< static_cast<int>(ConvForwardSpecialization)
// <<"_KS"<< static_cast<int>(gridwise_gemm.dynamic_tunable.gemm_k_spec)
<<
"_BS"
<<
static_cast
<
int
>
(
gridwise_gemm
.
dynamic_tunable
.
loop_over_spec
)
// << "_BT" << gridwise_gemm.dynamic_tunable.m_per_block << "x" << gridwise_gemm.dynamic_tunable.n_per_block << "x" << gridwise_gemm.dynamic_tunable.k_per_block
<<
"_TT"
<<
MPerThread
<<
"x"
<<
NPerThread
<<
"_A"
<<
string_local_buffer
(
UseALocalBuffer
)
<<
"_B"
<<
string_local_buffer
(
UseBLocalBuffer
)
<<
"_C"
<<
string_local_buffer
(
UseCLocalBuffer
)
;
if
constexpr
(
!
std
::
is_same
<
OutElementwiseOperation
,
ck
::
tensor_operation
::
cpu
::
element_wise
::
PassThrough
>::
value
)
{
str
<<
"_"
<<
OutElementwiseOperation
::
Name
();
}
// clang-format on
return
str
.
str
();
}
};
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck

#endif
include/ck/tensor_operation/cpu/grid/gridwise_direct_conv_bias_activation_add_avx2.hpp  (new file, 0 → 100644, view at 5db79de0)
#ifndef CK_GRIDWISE_DIRECT_CONV_BIAS_ACTIVATION_ADD_AVX2_HPP
#define CK_GRIDWISE_DIRECT_CONV_BIAS_ACTIVATION_ADD_AVX2_HPP

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/cpu/block/blockwise_gemm_avx2.hpp"
#include "ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2.hpp"
#include "ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp"
#include "ck/utility/dynamic_buffer_cpu.hpp"
#include "ck/utility/envvar.hpp"

#include <utility>
#include <unistd.h>
#include <omp.h>
#include <pthread.h>

namespace ck {
namespace cpu {
template <typename GridwiseDirectConv,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename FloatC0,
          typename FloatC1,
          typename AGridDesc,
          typename BGridDesc,
          typename CGridDesc,
          typename C0GridDesc,
          typename C1GridDesc,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
void kernel_direct_conv_nhwc_bias_activation_add_avx_mxn(
    const GridwiseDirectConv& gridwise_direct_conv,
    const FloatA* __restrict__ p_a_grid,
    const FloatB* __restrict__ p_b_grid,
    FloatC* __restrict__ p_c_grid,
    const FloatC0* __restrict__ p_c0_grid,
    const FloatC1* __restrict__ p_c1_grid,
    const AGridDesc& a_grid_desc,
    const BGridDesc& b_grid_desc,
    const CGridDesc& c_grid_desc,
    const C0GridDesc& c0_grid_desc,
    const C1GridDesc& c1_grid_desc,
    ck::index_t N,
    ck::index_t K,
    ck::index_t C,
    std::vector<ck::index_t> input_spatial_lengths,
    std::vector<ck::index_t> filter_spatial_lengths,
    std::vector<ck::index_t> output_spatial_lengths,
    std::vector<ck::index_t> conv_filter_strides,
    std::vector<ck::index_t> conv_filter_dilations,
    std::vector<ck::index_t> input_left_pads,
    std::vector<ck::index_t> input_right_pads,
    const AElementwiseOperation& a_element_op,
    const BElementwiseOperation& b_element_op,
    const CElementwiseOperation& c_element_op)
{
    gridwise_direct_conv.Run(p_a_grid,
                             p_b_grid,
                             p_c_grid,
                             p_c0_grid,
                             p_c1_grid,
                             a_grid_desc,
                             b_grid_desc,
                             c_grid_desc,
                             c0_grid_desc,
                             c1_grid_desc,
                             N,
                             K,
                             C,
                             input_spatial_lengths,
                             filter_spatial_lengths,
                             output_spatial_lengths,
                             conv_filter_strides,
                             conv_filter_dilations,
                             input_left_pads,
                             input_right_pads,
                             a_element_op,
                             b_element_op,
                             c_element_op);
}
template <typename FloatA,
          typename FloatB,
          typename FloatC,
          typename FloatC0,
          typename FloatC1,
          typename AGridDesc,
          typename BGridDesc,
          typename CGridDesc,
          typename C0GridDesc,
          typename C1GridDesc,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename ThreadwiseGemm_Dispatch,
          typename AThreadwiseCopy,
          typename BThreadwiseCopy,
          typename CThreadwiseCopy,
          typename ThreadMNAccessOrder, // how we acces gemm MN to utilize micro kernel
          bool UseALocalBuffer,
          bool UseBLocalBuffer,
          bool UseCLocalBuffer // if true, will allocate a buffer and write to it in kernel, then
                               // copy back to block buffer (need CThreadwiseCopy).
                               // if false, will write to C directly (no need CThreadwiseCopy)
          >
struct GridwiseDirectConvNHWCBiasActivationAddAvx2
{
    ck::tensor_operation::cpu::device::DeviceConvFwdDynamicTunable dynamic_tunable;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    // static constexpr auto Avx2RegisterVector = 8; // 8 floats

    static constexpr index_t MemAlignmentByte = 32; // 256bit

    GridwiseDirectConvNHWCBiasActivationAddAvx2(
        const ck::tensor_operation::cpu::device::DeviceConvFwdDynamicTunable dynamic_tunable_)
        : dynamic_tunable(dynamic_tunable_)
    {
    }
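    // Note (illustration, not part of the commit): MemAlignmentByte = 32 matches one 256-bit
    // YMM register (8 floats). The thread-local A/B/C scratch buffers further down are requested
    // at this alignment through DeviceAlignedMemCPU so packed AVX2 loads and stores stay aligned;
    // an equivalent plain C++ allocation would be std::aligned_alloc(32, n * sizeof(float)) with
    // the byte size rounded up to a multiple of 32.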
static
auto
GetABlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
k_per_blk
,
const
AGridDesc
&
a_grid_desc
)
{
if
constexpr
(
UseALocalBuffer
)
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// A : M, K
auto
a_block_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
m_per_blk
,
k_per_blk
));
return
a_block_desc_m_k
;
}
else
{
// A : K, M
auto
a_block_desc_k_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
k_per_blk
,
math
::
integer_least_multiple
(
m_per_blk
,
ThreadwiseGemm_Dispatch
::
MatrixAMinVectorSize
)));
return
a_block_desc_k_m
;
}
}
else
{
return
a_grid_desc
;
}
}
static
auto
GetBBlockDescriptor
(
const
ck
::
index_t
k_per_blk
,
const
ck
::
index_t
n_per_blk
,
const
BGridDesc
&
b_grid_desc
)
{
if
constexpr
(
UseBLocalBuffer
)
{
// n_per_blk should be 8x
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// B : K, N
auto
b_block_desc_k_n
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
k_per_blk
,
n_per_blk
));
return
b_block_desc_k_n
;
}
else
{
// B : N/8, K, N8
auto
b_block_desc_n0_k_n1
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
math
::
integer_divide_ceil
(
n_per_blk
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
k_per_blk
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
return
b_block_desc_n0_k_n1
;
}
}
else
{
return
b_grid_desc
;
}
}
static
auto
GetCBlockDescriptor
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
n_per_blk
,
const
CGridDesc
&
c_grid_desc
)
{
if
constexpr
(
UseCLocalBuffer
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
m_per_blk
,
n_per_blk
));
}
else
return
c_grid_desc
;
}
static
auto
GetASliceLength
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
k_per_blk
)
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// A : M, K
return
ck
::
make_multi_index
(
m_per_blk
,
k_per_blk
);
}
else
{
// A : K, M
return
ck
::
make_multi_index
(
k_per_blk
,
math
::
integer_least_multiple
(
m_per_blk
,
ThreadwiseGemm_Dispatch
::
MatrixAMinVectorSize
));
}
}
static
auto
GetBSliceLength
(
const
ck
::
index_t
k_per_blk
,
const
ck
::
index_t
n_per_blk
)
{
// n_per_blk should be 8x
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// B : K, N
return
ck
::
make_multi_index
(
k_per_blk
,
math
::
integer_least_multiple
(
n_per_blk
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
));
}
else
{
// B : N/8, K, N8
return
ck
::
make_multi_index
(
math
::
integer_divide_ceil
(
n_per_blk
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
),
k_per_blk
,
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
);
}
}
static
auto
GetCSliceLength
(
const
ck
::
index_t
m_per_blk
,
const
ck
::
index_t
n_per_blk
)
{
return
ck
::
make_multi_index
(
m_per_blk
,
n_per_blk
);
}
static
auto
GetAIndex
(
const
ck
::
index_t
i_m
,
const
ck
::
index_t
i_k
)
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixALayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// A : M, K
return
ck
::
make_multi_index
(
i_m
,
i_k
);
}
else
{
// A : K, M
return
ck
::
make_multi_index
(
i_k
,
i_m
);
}
}
static
auto
GetBIndex
(
const
ck
::
index_t
i_k
,
const
ck
::
index_t
i_n
)
{
// i_n should be 8x
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// B : K, N
return
ck
::
make_multi_index
(
i_k
,
i_n
);
}
else
{
// B : N/8, K, N8
return
ck
::
make_multi_index
(
i_n
/
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
,
i_k
,
i_n
%
ThreadwiseGemm_Dispatch
::
MatrixBMinVectorSize
);
}
}
static
auto
GetCIndex
(
const
ck
::
index_t
i_m
,
const
ck
::
index_t
i_n
)
{
return
ck
::
make_multi_index
(
i_m
,
i_n
);
}
bool
CheckValidity
(
const
AGridDesc
&
a_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
CGridDesc
&
c_grid_desc
)
{
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
bool
is_valid
=
true
;
const
auto
GemmN
=
c_grid_desc
.
GetLength
(
I1
);
if
constexpr
(
UseCLocalBuffer
)
{
if
(
dynamic_tunable
.
loop_over_spec
==
ck
::
tensor_operation
::
cpu
::
device
::
ConvolutionForwardBlockLoopOverSpecialization_t
::
LoopOver_MKN
&&
dynamic_tunable
.
n_per_block
<
GemmN
)
is_valid
&=
false
;
}
else
{
// TODO: need check c grid is simple transform?
if
(
GemmN
%
8
!=
0
)
is_valid
&=
false
;
}
return
is_valid
;
}
static
intptr_t
GetBBlockStartOffset
(
const
BGridDesc
&
b_grid_desc
,
const
intptr_t
i_k
,
const
intptr_t
i_n
)
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// K * N
return
i_n
;
}
else
{
// N/8 * K * 8
return
i_n
*
b_grid_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}]
+
i_k
*
8
;
}
}
static
intptr_t
GetCBlockStartOffset
(
const
CGridDesc
&
c_grid_desc
,
const
intptr_t
i_m
,
const
intptr_t
i_n
)
{
return
i_m
*
c_grid_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}]
+
i_n
;
}
static
intptr_t
GetBLeadingElement
(
const
BGridDesc
&
b_grid_desc
)
{
if
constexpr
(
std
::
is_same
<
typename
ThreadwiseGemm_Dispatch
::
MatrixBLayout
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
// K * N
return
b_grid_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
}
else
{
// N/8 * K * 8
return
b_grid_desc
.
GetLength
(
Number
<
1
>
{})
*
b_grid_desc
.
GetLength
(
Number
<
2
>
{});
}
}
static
intptr_t
GetCLeadingElement
(
const
CGridDesc
&
c_grid_desc
)
{
return
c_grid_desc
.
GetTransforms
()[
Number
<
0
>
{}].
GetUpperLengths
()[
Number
<
1
>
{}];
}
void
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatC0
*
__restrict__
p_c0_grid
,
const
FloatC1
*
__restrict__
p_c1_grid
,
const
AGridDesc
&
a_grid_desc
,
const
BGridDesc
&
b_grid_desc
,
const
CGridDesc
&
c_grid_desc
,
const
C0GridDesc
&
c0_grid_desc
,
const
C1GridDesc
&
c1_grid_desc
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
)
const
{
const
ck
::
index_t
m_per_thread
=
ThreadwiseGemm_Dispatch
::
ThreadMaxMr
;
const
ck
::
index_t
n_per_thread
=
ThreadwiseGemm_Dispatch
::
ThreadMaxNr
;
const
ck
::
index_t
k_per_thread
=
C
;
const
auto
GemmM
=
c_grid_desc
.
GetLength
(
I0
);
const
auto
GemmN
=
c_grid_desc
.
GetLength
(
I1
);
const
auto
GemmK
=
a_grid_desc
.
GetLength
(
I1
);
const
intptr_t
Hi
=
input_spatial_lengths
[
0
];
const
intptr_t
Wi
=
input_spatial_lengths
[
1
];
const
intptr_t
Ho
=
output_spatial_lengths
[
0
];
const
intptr_t
Wo
=
output_spatial_lengths
[
1
];
const
intptr_t
Y
=
filter_spatial_lengths
[
0
];
const
intptr_t
X
=
filter_spatial_lengths
[
1
];
const
intptr_t
Sy
=
conv_filter_strides
[
0
];
const
intptr_t
Sx
=
conv_filter_strides
[
1
];
const
intptr_t
Dy
=
conv_filter_dilations
[
0
];
const
intptr_t
Dx
=
conv_filter_dilations
[
1
];
const
intptr_t
Py
=
input_left_pads
[
0
];
const
intptr_t
Px
=
input_left_pads
[
1
];
const
intptr_t
X_Dx
=
X
*
Dx
;
// const index_t Y_Dy = Y * Dy;
// const index_t InRightPadH = input_right_pads[0];
// const index_t InRightPadW = input_right_pads[1];
constexpr
auto
a_block_copy_dim
=
AGridDesc
::
GetNumOfDimension
();
constexpr
auto
b_block_copy_dim
=
BGridDesc
::
GetNumOfDimension
();
auto
a_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
const_cast
<
FloatA
*>
(
p_a_grid
),
a_grid_desc
.
GetElementSpaceSize
());
auto
b_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
const_cast
<
FloatB
*>
(
p_b_grid
),
b_grid_desc
.
GetElementSpaceSize
());
auto
c_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
FloatC
*>
(
p_c_grid
),
c_grid_desc
.
GetElementSpaceSize
());
auto
c0_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
const
FloatC0
*>
(
p_c0_grid
),
c0_grid_desc
.
GetElementSpaceSize
());
auto
c1_grid_buf
=
ck
::
cpu
::
make_dynamic_buffer
<
ck
::
AddressSpaceEnum
::
Global
>
(
reinterpret_cast
<
const
FloatC1
*>
(
p_c1_grid
),
c1_grid_desc
.
GetElementSpaceSize
());
int
total_threads
=
omp_get_max_threads
();
if
(
total_threads
>
1
&&
ck
::
getenv_int
(
"CK_CPU_BIND_CORE"
,
0
)
!=
0
)
{
#pragma omp parallel
{
int
tid
=
omp_get_thread_num
();
cpu_set_t
set
;
CPU_ZERO
(
&
set
);
CPU_SET
(
tid
,
&
set
);
if
(
sched_setaffinity
(
0
,
sizeof
(
set
),
&
set
)
==
-
1
)
{
throw
std
::
runtime_error
(
"wrong! fail to set thread affinity"
);
}
}
}
auto
devide_thread
=
[](
ck
::
index_t
n_
,
ck
::
index_t
length_
,
ck
::
index_t
factor_
)
{
ck
::
index_t
t_
=
n_
;
while
(
t_
>
length_
&&
(
t_
%
factor_
==
0
))
{
t_
/=
factor_
;
}
return
t_
;
};
const
intptr_t
num_works_n
=
N
;
const
intptr_t
num_works_ho
=
Ho
;
// const intptr_t num_works_nho = N * Ho;
const
intptr_t
num_works_wo
=
math
::
integer_divide_ceil
(
Wo
,
m_per_thread
);
const
intptr_t
num_works_k
=
math
::
integer_divide_ceil
(
K
,
n_per_thread
);
auto
distribute_num_threads_n_ho_wo_k
=
[
&
](
ck
::
index_t
&
num_threads_n_
,
ck
::
index_t
&
num_threads_ho_
,
ck
::
index_t
&
num_threads_wo_
,
ck
::
index_t
&
num_threads_k_
)
{
// TODO: only consider multiply of 2 to divide threads
ck
::
index_t
num_threads
=
total_threads
;
num_threads_n_
=
devide_thread
(
num_threads
,
num_works_n
,
2
);
num_threads
=
num_threads
/
num_threads_n_
;
num_threads_ho_
=
devide_thread
(
num_threads
,
num_works_ho
,
2
);
num_threads
=
num_threads
/
num_threads_ho_
;
num_threads_wo_
=
devide_thread
(
num_threads
,
num_works_wo
,
2
);
num_threads
=
num_threads
/
num_threads_wo_
;
num_threads_k_
=
devide_thread
(
num_threads
,
num_works_k
,
2
);
// num_threads = num_threads / num_threads_k_;
};
ck
::
index_t
num_threads_n
;
ck
::
index_t
num_threads_ho
;
ck
::
index_t
num_threads_wo
;
ck
::
index_t
num_threads_k
;
distribute_num_threads_n_ho_wo_k
(
num_threads_n
,
num_threads_ho
,
num_threads_wo
,
num_threads_k
);
const
ck
::
index_t
num_works_n_per_thread
=
math
::
integer_divide_ceil
(
num_works_n
,
num_threads_n
);
const
ck
::
index_t
num_works_ho_per_thread
=
math
::
integer_divide_ceil
(
num_works_ho
,
num_threads_ho
);
const
ck
::
index_t
num_works_wo_per_thread
=
math
::
integer_divide_ceil
(
num_works_wo
,
num_threads_wo
);
const
ck
::
index_t
num_works_k_per_thread
=
math
::
integer_divide_ceil
(
num_works_k
,
num_threads_k
);
// printf("num_threads_nho:%d, num_threads_wo:%d, num_threads_k:%d |
// num_works_nho_per_thread:%d, num_works_wo_per_thread:%d, num_works_k_per_thread:%d\n",
// num_threads_nho, num_threads_wo, num_threads_k, num_works_nho_per_thread,
// num_works_wo_per_thread, num_works_k_per_thread); fflush(stdout);
if
((
X
-
1
)
*
Dx
+
1
<=
Px
||
(
Y
-
1
)
*
Dy
+
1
<=
Py
)
{
// padding zero case, outpout will have zero due to upsampling
// TODO: This is ugly and slow
ck
::
cpu
::
avx2_util
::
memset32_avx2
(
&
c_grid_buf
.
p_data_
[
0
],
0
,
N
*
Ho
*
Wo
*
K
);
// printf("___ clear\n");
}
        if(dynamic_tunable.loop_over_spec ==
           ck::tensor_operation::cpu::device::ConvolutionForwardBlockLoopOverSpecialization_t::LoopOver_MNK)
        {
            // only parallel in gemm m dim
#pragma omp parallel
            {
                DeviceAlignedMemCPU a_block_mem(
                    UseALocalBuffer ? m_per_thread * k_per_thread * sizeof(FloatA) : 0,
                    MemAlignmentByte);
                DeviceAlignedMemCPU b_block_mem(
                    UseBLocalBuffer ? k_per_thread * n_per_thread * sizeof(FloatB) : 0,
                    MemAlignmentByte);
                DeviceAlignedMemCPU c_block_mem(
                    UseCLocalBuffer ? (m_per_thread * n_per_thread * sizeof(FloatC)) : 0,
                    MemAlignmentByte);

                auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
                    UseALocalBuffer ? reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf)
                                    : const_cast<FloatA*>(p_a_grid),
                    UseALocalBuffer ? a_block_mem.mMemSize / sizeof(FloatA)
                                    : a_grid_desc.GetElementSpaceSize());

                auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
                    UseBLocalBuffer ? reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf)
                                    : const_cast<FloatB*>(p_b_grid),
                    UseBLocalBuffer ? b_block_mem.mMemSize / sizeof(FloatB)
                                    : b_grid_desc.GetElementSpaceSize());

                auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
                    UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
                                    : reinterpret_cast<FloatC*>(p_c_grid),
                    UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
                                    : c_grid_desc.GetElementSpaceSize());

                ck::index_t tid = omp_get_thread_num();

                const ck::index_t tid_n = tid % num_threads_n;
                tid /= num_threads_n;
                const ck::index_t tid_ho = tid % num_threads_ho;
                tid /= num_threads_ho;
                const ck::index_t tid_wo = tid % num_threads_wo;
                tid /= num_threads_wo;
                const ck::index_t tid_k = tid;

                ck::cpu::ThreadwiseGemmParam param;
                // param.Kr = k_per_block;
                param.lda   = Sx * C * sizeof(FloatA);
                param.ldb   = GetBLeadingElement(b_grid_desc) * sizeof(FloatB);
                param.ldc   = GetCLeadingElement(c_grid_desc) * sizeof(FloatC);
                param.alpha = 1.0f; // TODO
                param.Kr    = C;

                // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
                // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w

                // ck::index_t i_nho = tid_nho * num_works_nho_per_thread;
                // ck::index_t i_ho  = i_nho % Ho;
                // ck::index_t i_n   = i_nho / Ho;
                // auto accumulate_n_ho = [&]() {
                //     i_ho++;
                //     if(i_ho >= Wo)
                //     {
                //         i_ho = 0;
                //         i_n++;
                //     }
                // };

                for(intptr_t i_n = tid_n * num_works_n_per_thread;
                    (i_n < (tid_n + 1) * num_works_n_per_thread) && i_n < num_works_n;
                    i_n += 1)
                {
                    for(intptr_t i_ho = tid_ho * num_works_ho_per_thread;
                        (i_ho < (tid_ho + 1) * num_works_ho_per_thread) && i_ho < num_works_ho;
                        i_ho += 1)
                    {
                        // for input
                        intptr_t i_hi_no_y = i_ho * Sy - Py;

                        for(intptr_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
                            i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread && i_wo < Wo;
                            i_wo += m_per_thread)
                        {
                            intptr_t current_wo_size_no_dx =
                                ck::math::min(Wo - i_wo, (intptr_t)m_per_thread);
                            intptr_t i_wi_no_x = i_wo * Sx - Px;

                            // printf("-- i_nho:%d, i_wo:%d, num_works_nho:%d, num_threads_nho:%d(Hi:%d, Wi:%d)\n",
                            //        i_nho, i_wo, num_works_nho, num_threads_nho, Hi, Wi); fflush(stdout);

                            for(intptr_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
                                i_k < (tid_k + 1) * num_works_k_per_thread * n_per_thread;
                                i_k += n_per_thread)
                            {
                                intptr_t i_dx           = 0;
                                intptr_t i_dy           = 0;
                                bool accmulate_c        = false;
                                intptr_t current_k_size = ck::math::min(K - i_k, (intptr_t)n_per_thread);

                                auto accumulate_dy_dx = [&]() {
                                    i_dx += Dx;
                                    if(i_dx >= X_Dx)
                                    {
                                        i_dx = 0;
                                        i_dy += Dy;
                                    }
                                };

                                for(intptr_t i_yxc = 0; i_yxc < (Y * X * C); i_yxc += C, accumulate_dy_dx())
                                {
                                    intptr_t current_i_wo = i_wo;
                                    intptr_t i_hi         = i_hi_no_y + i_dy;
                                    if(i_hi < 0 || i_hi >= Hi)
                                        continue;

                                    intptr_t i_wi            = i_wi_no_x + i_dx;
                                    intptr_t current_wo_size = current_wo_size_no_dx;
                                    intptr_t pad_wo_size     = 0;

                                    // with left padding we may never get a chance to write the
                                    // zeros that the padding implies, so clear them manually
                                    if(i_wi < 0)
                                    {
                                        intptr_t wi_to_zero_length = -i_wi; // keep this a positive number
                                        intptr_t steps_wo_turn_possitive =
                                            (wi_to_zero_length + Sx - 1) / Sx;
                                        // how many steps wo has to move so that wi becomes non-negative
                                        current_wo_size -= steps_wo_turn_possitive;
                                        if(current_wo_size <= 0)
                                            continue;
                                        current_i_wo += steps_wo_turn_possitive;
                                        if(!accmulate_c)
                                            pad_wo_size = steps_wo_turn_possitive;
                                        // if already accumulating, no need to clear manually
                                        i_wi += steps_wo_turn_possitive * Sx;
                                        // now i_wi is non-negative
                                    }
                                    if(i_wi >= Wi)
                                        continue;

                                    // shrink right wi/wo
                                    if((i_wi + ((current_wo_size - 1) * Sx)) >= Wi)
                                    {
                                        // printf(" ->[r] i_wi:%d, r:%d(%d), ", i_wi,
                                        //        i_wi + ((current_wo_size - 1) * Sx), current_wo_size);
                                        current_wo_size = (Wi - 1 - i_wi) / Sx + 1;
                                        // NOTE: be careful about why this must be computed this way
                                        if(current_wo_size <= 0)
                                            continue;
                                    }

                                    param.accmulate_c = accmulate_c ? 1 : 0;
                                    accmulate_c       = true;

                                    intptr_t current_input_offset =
                                        i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C;

                                    if(pad_wo_size != 0)
                                    {
                                        for(intptr_t i_wo_pad = 0; i_wo_pad < pad_wo_size; i_wo_pad++)
                                        {
                                            const intptr_t offset_c = GetCBlockStartOffset(
                                                c_grid_desc, (i_n * Ho + i_ho) * Wo + i_wo_pad, i_k);
                                            // printf("pad_wo_size:%d, current_k_block_size:%d, clear offset_c:%d\n",
                                            //        pad_wo_size, current_k_size, offset_c); fflush(stdout);
                                            ck::cpu::avx2_util::memset32_avx2(
                                                &c_block_buf.p_data_[offset_c], 0, current_k_size);
                                        }
                                    }

                                    const intptr_t offset_a = current_input_offset;
                                    const intptr_t offset_b = GetBBlockStartOffset(b_grid_desc, i_yxc, i_k);
                                    const intptr_t offset_c = GetCBlockStartOffset(
                                        c_grid_desc, (i_n * Ho + i_ho) * Wo + current_i_wo, i_k);

                                    // printf("offset_a:%lu, offset_b:%lu, offset_c:%lu, i_n:%d, i_hi:%d, i_wi:%d, "
                                    //        "i_dx:%d, i_dy:%d, i_k:%d, i_ho:%d, i_wo:%d, current_wo_size:%d, "
                                    //        "current_k_size:%d, lda:%d, ldb:%d, ldc:%d, acc:%d\n",
                                    //        offset_a, offset_b, offset_c, i_n, i_hi, i_wi, i_dx, i_dy, i_k, i_ho,
                                    //        current_i_wo, current_wo_size, current_k_size,
                                    //        param.lda / sizeof(FloatA), param.ldb / sizeof(FloatB),
                                    //        param.ldc / sizeof(FloatC), param.accmulate_c); fflush(stdout);

                                    param.p_a = &a_block_buf.p_data_[offset_a];
                                    param.p_b = &b_block_buf.p_data_[offset_b];
                                    param.p_c = &c_block_buf.p_data_[offset_c];

                                    ThreadwiseGemm_Dispatch::Run(&param, current_wo_size, current_k_size);
                                }
                            }
                        }

                        // thread block wise fusion
                        for(intptr_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
                            i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread && i_wo < Wo;
                            i_wo += 1)
                        {
                            const intptr_t n_size =
                                ck::math::min(K - tid_k * num_works_k_per_thread * n_per_thread,
                                              num_works_k_per_thread * n_per_thread);

                            if constexpr(CThreadwiseCopy::FuseBias && CThreadwiseCopy::FuseAdd)
                            {
                                const intptr_t offset_c = GetCBlockStartOffset(
                                    c_grid_desc, (i_n * Ho + i_ho) * Wo + i_wo, 0);
                                const intptr_t offset_c0 = 0;

                                avx2_util::memcpy32_avx2_with_extra_2src(&c_block_buf.p_data_[offset_c],
                                                                         &c_block_buf.p_data_[offset_c],
                                                                         &c0_grid_buf.p_data_[offset_c0],
                                                                         &c1_grid_buf.p_data_[offset_c],
                                                                         n_size,
                                                                         c_element_op);
                            }
                            else
                            {
                            }
                        }
                    }
                }
            }
        }
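        // Above: LoopOver_MNK keeps the gemm-n loop (output channels, i_k) outside the y*x*c
        // reduction, so each n_per_thread-wide slice of output channels walks the whole filter
        // window again.
        // Below: LoopOver_MKN hoists the y*x*c reduction outside the gemm-n loop, so each
        // filter tap updates every output-channel slice owned by this thread before the window
        // advances.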
        else if(dynamic_tunable.loop_over_spec ==
                ck::tensor_operation::cpu::device::ConvolutionForwardBlockLoopOverSpecialization_t::LoopOver_MKN)
        {
            // only parallel in gemm m dim
#pragma omp parallel
            {
                DeviceAlignedMemCPU a_block_mem(
                    UseALocalBuffer ? m_per_thread * k_per_thread * sizeof(FloatA) : 0,
                    MemAlignmentByte);
                DeviceAlignedMemCPU b_block_mem(
                    UseBLocalBuffer ? k_per_thread * n_per_thread * sizeof(FloatB) : 0,
                    MemAlignmentByte);
                DeviceAlignedMemCPU c_block_mem(
                    UseCLocalBuffer ? (m_per_thread * n_per_thread * sizeof(FloatC)) : 0,
                    MemAlignmentByte);

                auto a_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
                    UseALocalBuffer ? reinterpret_cast<FloatA*>(a_block_mem.mpDeviceBuf)
                                    : const_cast<FloatA*>(p_a_grid),
                    UseALocalBuffer ? a_block_mem.mMemSize / sizeof(FloatA)
                                    : a_grid_desc.GetElementSpaceSize());

                auto b_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
                    UseBLocalBuffer ? reinterpret_cast<FloatB*>(b_block_mem.mpDeviceBuf)
                                    : const_cast<FloatB*>(p_b_grid),
                    UseBLocalBuffer ? b_block_mem.mMemSize / sizeof(FloatB)
                                    : b_grid_desc.GetElementSpaceSize());

                auto c_block_buf = ck::cpu::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(
                    UseCLocalBuffer ? reinterpret_cast<FloatC*>(c_block_mem.mpDeviceBuf)
                                    : reinterpret_cast<FloatC*>(p_c_grid),
                    UseCLocalBuffer ? c_block_mem.mMemSize / sizeof(FloatC)
                                    : c_grid_desc.GetElementSpaceSize());

                ck::cpu::ThreadwiseGemmParam param;
                // param.Kr = k_per_block;
                param.lda   = Sx * C * sizeof(FloatA);
                param.ldb   = GetBLeadingElement(b_grid_desc) * sizeof(FloatB);
                param.ldc   = GetCLeadingElement(c_grid_desc) * sizeof(FloatC);
                param.alpha = 1.0f; // TODO
                param.Kr    = C;

                // ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h
                // iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w

                ck::index_t tid = omp_get_thread_num();

                const ck::index_t tid_n = tid % num_threads_n;
                tid /= num_threads_n;
                const ck::index_t tid_ho = tid % num_threads_ho;
                tid /= num_threads_ho;
                const ck::index_t tid_wo = tid % num_threads_wo;
                tid /= num_threads_wo;
                const ck::index_t tid_k = tid;

                for(intptr_t i_n = tid_n * num_works_n_per_thread;
                    (i_n < (tid_n + 1) * num_works_n_per_thread) && i_n < num_works_n;
                    i_n += 1)
                {
                    for(intptr_t i_ho = tid_ho * num_works_ho_per_thread;
                        (i_ho < (tid_ho + 1) * num_works_ho_per_thread) && i_ho < num_works_ho;
                        i_ho += 1)
                    {
                        // for input
                        intptr_t i_hi_no_y = i_ho * Sy - Py;

                        for(intptr_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
                            i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread && i_wo < Wo;
                            i_wo += m_per_thread)
                        {
                            intptr_t current_wo_size_no_dx =
                                ck::math::min(Wo - i_wo, (intptr_t)m_per_thread);
                            intptr_t i_wi_no_x = i_wo * Sx - Px;

                            intptr_t i_dx    = 0;
                            intptr_t i_dy    = 0;
                            bool accmulate_c = false;

                            // printf("-- [%d] i_n:%d, i_ho:%d, i_wo:%d, num_works_n:%d, num_threads_n:%d(Hi:%d, Wi:%d), "
                            //        "current_wo_size_no_dx:%d, m_per_thread:%d\n",
                            //        tid, i_n, i_ho, i_wo, num_works_n, num_threads_n, Hi, Wi,
                            //        current_wo_size_no_dx, m_per_thread); fflush(stdout);

                            auto accumulate_dy_dx = [&]() {
                                i_dx += Dx;
                                if(i_dx >= X_Dx)
                                {
                                    i_dx = 0;
                                    i_dy += Dy;
                                }
                            };

                            for(intptr_t i_yxc = 0; i_yxc < (Y * X * C); i_yxc += C, accumulate_dy_dx())
                            {
                                intptr_t current_i_wo = i_wo;
                                intptr_t i_hi         = i_hi_no_y + i_dy;
                                bool run_pad_only     = false;

                                if(i_hi < 0 || i_hi >= Hi)
                                    continue;

                                intptr_t i_wi            = i_wi_no_x + i_dx;
                                intptr_t current_wo_size = current_wo_size_no_dx;
                                intptr_t pad_wo_size     = 0;

                                // with left padding we may never get a chance to write the zeros
                                // that the padding implies, so clear them manually
                                /* left corner shift
                                 * when i_wi is negative, shift i_wo to the right until i_wi becomes non-negative
                                 *   sx  px  i_wi           steps_wo_turn_possitive
                                 *   1   0   0, 1, 2...     0
                                 *   2   0   0, 2, 4...     0
                                 *   1   1   -1, 0, 1...    1
                                 *   2   1   -1, 1, 3...    1
                                 *   2   2   -2, 0, 2...    1
                                 *   2   3   -3, -1, 1...   2
                                 *   3   1   -1, 2, 5...    1
                                 *   3   2   -2, 1, 4...    1
                                 *   3   3   -3, 0, 3...    1
                                 *   3   4   -4, -1, 2...   2
                                 */
                                if(i_wi < 0)
                                {
                                    intptr_t wi_to_zero_length = -i_wi; // keep this a positive number
                                    intptr_t steps_wo_turn_possitive =
                                        (wi_to_zero_length + Sx - 1) / Sx;
                                    // how many steps wo has to move so that wi becomes non-negative
                                    current_wo_size -= steps_wo_turn_possitive;
                                    // printf("--- current_wo_size:%d, i_wi:%d\n", current_wo_size, i_wi);
                                    if(current_wo_size <= 0)
                                        continue;
                                    current_i_wo += steps_wo_turn_possitive;
                                    if(!accmulate_c)
                                        pad_wo_size = steps_wo_turn_possitive;
                                    // if already accumulating, no need to clear manually
                                    i_wi += steps_wo_turn_possitive * Sx;
                                    // now i_wi is non-negative
                                }
                                if(i_wi >= Wi)
                                {
                                    continue;
                                }

                                // shrink right wi/wo
                                if((i_wi + ((current_wo_size - 1) * Sx)) >= Wi)
                                {
                                    // printf(" ->[r] i_wi:%d, r:%d(%d), ", i_wi,
                                    //        i_wi + ((current_wo_size - 1) * Sx), current_wo_size);
                                    current_wo_size = (Wi - 1 - i_wi) / Sx + 1;
                                    // NOTE: be careful about why this must be computed this way
                                    if(current_wo_size <= 0)
                                        continue;
                                }

                                param.accmulate_c = accmulate_c ? 1 : 0;
                                accmulate_c       = true;

                                intptr_t current_input_offset =
                                    i_n * Hi * Wi * C + i_hi * Wi * C + i_wi * C;

                                if(pad_wo_size != 0)
                                {
                                    // manually clear to zero. this may (and only may) be needed
                                    // once along the gemm_k reduction
                                    intptr_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
                                    intptr_t current_k_block_size = ck::math::min(
                                        K - i_k, (intptr_t)num_works_k_per_thread * n_per_thread);
                                    const intptr_t offset_c =
                                        GetCBlockStartOffset(c_grid_desc, (i_n * Ho + i_ho) * Wo, i_k);
                                    // printf("[%d] pad_wo_size:%d, current_k_block_size:%d, offset_c:%d, i_wo:%d\n",
                                    //        tid, pad_wo_size, current_k_block_size, offset_c, i_wo); fflush(stdout);
                                    ck::cpu::avx2_util::memset32_avx2(&c_block_buf.p_data_[offset_c],
                                                                      0,
                                                                      current_k_block_size * pad_wo_size);
                                }

                                if(run_pad_only)
                                    continue;

                                for(intptr_t i_k = tid_k * num_works_k_per_thread * n_per_thread;
                                    i_k < (tid_k + 1) * num_works_k_per_thread * n_per_thread;
                                    i_k += n_per_thread)
                                {
                                    intptr_t current_k_size = ck::math::min(K - i_k, (intptr_t)n_per_thread);

                                    const intptr_t offset_a = current_input_offset;
                                    const intptr_t offset_b = GetBBlockStartOffset(b_grid_desc, i_yxc, i_k);
                                    const intptr_t offset_c = GetCBlockStartOffset(
                                        c_grid_desc, (i_n * Ho + i_ho) * Wo + current_i_wo, i_k);

                                    // printf("[%d] offset_a:%lu, offset_b:%lu, offset_c:%lu, i_n:%d, i_hi:%d, "
                                    //        "i_wi:%d, i_dx:%d, i_dy:%d, i_k:%d, i_ho:%d, i_wo:%d, "
                                    //        "current_wo_size:%d, lda:%d, ldb:%d\n",
                                    //        tid, offset_a, offset_b, offset_c, i_n, i_hi, i_wi, i_dx, i_dy,
                                    //        i_k, i_ho, current_i_wo, current_wo_size,
                                    //        param.lda / sizeof(FloatA), param.ldb / sizeof(FloatB)); fflush(stdout);

                                    param.p_a = &a_block_buf.p_data_[offset_a];
                                    param.p_b = &b_block_buf.p_data_[offset_b];
                                    param.p_c = &c_block_buf.p_data_[offset_c];

                                    ThreadwiseGemm_Dispatch::Run(&param, current_wo_size, current_k_size);
                                }
                            }
                        }

                        // thread block wise fusion
                        for(intptr_t i_wo = tid_wo * num_works_wo_per_thread * m_per_thread;
                            i_wo < (tid_wo + 1) * num_works_wo_per_thread * m_per_thread && i_wo < Wo;
                            i_wo += 1)
                        {
                            const intptr_t n_size =
                                ck::math::min(K - tid_k * num_works_k_per_thread * n_per_thread,
                                              num_works_k_per_thread * n_per_thread);

                            if constexpr(CThreadwiseCopy::FuseBias && CThreadwiseCopy::FuseAdd)
                            {
                                const intptr_t offset_c = GetCBlockStartOffset(
                                    c_grid_desc, (i_n * Ho + i_ho) * Wo + i_wo, 0);
                                const intptr_t offset_c0 = 0;

                                avx2_util::memcpy32_avx2_with_extra_2src(&c_block_buf.p_data_[offset_c],
                                                                         &c_block_buf.p_data_[offset_c],
                                                                         &c0_grid_buf.p_data_[offset_c0],
                                                                         &c1_grid_buf.p_data_[offset_c],
                                                                         n_size,
                                                                         c_element_op);
                            }
                            else
                            {
                            }
                        }
                    }
                }
            }
        }
    }
};
} // namespace cpu
} // namespace ck

#endif
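Two pieces of index arithmetic in the gridwise kernel above are worth sanity-checking in isolation: the decomposition of the flat OpenMP thread id into (tid_n, tid_ho, tid_wo, tid_k), and the left-pad shift steps_wo_turn_possitive = (-i_wi + Sx - 1) / Sx, which is a ceiling division and reproduces the rows of the left-corner-shift table. A minimal standalone sketch (the num_threads_* values and the thread id below are made-up example numbers, not values from the commit):

#include <cassert>
#include <cstdint>

int main()
{
    // 1) flat thread id -> (tid_n, tid_ho, tid_wo, tid_k), same order as the kernel:
    //    n is the fastest-varying dimension, k the slowest.
    const int num_threads_n = 2, num_threads_ho = 2, num_threads_wo = 3; // example split
    int tid = 10; // e.g. omp_get_thread_num() returned 10
    const int tid_n  = tid % num_threads_n;  tid /= num_threads_n;
    const int tid_ho = tid % num_threads_ho; tid /= num_threads_ho;
    const int tid_wo = tid % num_threads_wo; tid /= num_threads_wo;
    const int tid_k  = tid;
    assert(tid_n == 0 && tid_ho == 1 && tid_wo == 2 && tid_k == 0);

    // 2) left-pad shift: smallest number of wo steps that makes i_wi non-negative,
    //    i.e. ceil(-i_wi / Sx). Checked against a few rows of the comment table.
    auto steps_wo_turn_positive = [](std::intptr_t i_wi, std::intptr_t Sx) {
        return (-i_wi + Sx - 1) / Sx;
    };
    assert(steps_wo_turn_positive(-1, 1) == 1); // sx=1, px=1
    assert(steps_wo_turn_positive(-3, 2) == 2); // sx=2, px=3
    assert(steps_wo_turn_positive(-4, 3) == 2); // sx=3, px=4
    return 0;
}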
include/ck/tensor_operation/cpu/thread/threadwise_tensor_slice_transfer_avx2_specialization.hpp
View file @
5db79de0
...
@@ -1768,6 +1768,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN
    static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
    using Index                       = MultiIndex<nDim>;

    static constexpr bool FuseBias = true;
    static constexpr bool FuseAdd  = true;

    constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_Residual_MxN(
        const SrcDesc& src_desc,
        const Index&,
...
@@ -2434,6 +2437,9 @@ struct ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN
    static constexpr ck::index_t nDim = SrcDesc::GetNumOfDimension();
    using Index                       = MultiIndex<nDim>;

    static constexpr bool FuseBias = true;
    static constexpr bool FuseAdd  = false;

    constexpr ThreadwiseTensorSliceTransferAvx2Specialization_MatC_Store_Bias_MxN(
        const SrcDesc& src_desc,
        const Index&,
...
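These FuseBias/FuseAdd flags are exactly what the gridwise kernel above keys its fused epilogue on (if constexpr(CThreadwiseCopy::FuseBias && CThreadwiseCopy::FuseAdd)). A minimal sketch of that compile-time dispatch pattern, using stand-in policy structs: only the two boolean members mirror the real specializations, everything else is illustrative.

#include <cstdio>

// Stand-ins for the two threadwise-copy specializations; only FuseBias/FuseAdd
// correspond to the real code, the struct names here are invented.
struct StoreBiasResidual { static constexpr bool FuseBias = true; static constexpr bool FuseAdd = true;  };
struct StoreBias         { static constexpr bool FuseBias = true; static constexpr bool FuseAdd = false; };

template <typename CThreadwiseCopy>
void run_epilogue()
{
    if constexpr(CThreadwiseCopy::FuseBias && CThreadwiseCopy::FuseAdd)
        std::printf("bias + residual-add fused into the C store\n");
    else if constexpr(CThreadwiseCopy::FuseBias)
        std::printf("bias fused into the C store\n");
    else
        std::printf("plain C store\n");
}

int main()
{
    run_epilogue<StoreBiasResidual>(); // bias + residual-add fused into the C store
    run_epilogue<StoreBias>();         // bias fused into the C store
    return 0;
}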
library/src/tensor_operation_instance/cpu/conv2d_fwd_bias_activation_add/device_conv2d_direct_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk_instance.cpp
0 → 100644
View file @
5db79de0
#include <stdlib.h>
#include <utility>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/cpu/device/convolution_forward_specialization_cpu.hpp"
#include "ck/tensor_operation/cpu/device/device_convnd_direct_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk.hpp"
#include "ck/tensor_operation/cpu/element/element_wise_operation_cpu.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace cpu {
namespace device {
namespace device_conv2d_fwd_bias_activation_add_avx2_instance {

using InType  = float;
using WeiType = float;
using OutType = float;
using AccType = float;

using InLayout  = ck::tensor_layout::gemm::RowMajor;    // NHWC
using WeiLayout = ck::tensor_layout::gemm::ColumnMajor; // KYXCK8

static constexpr bool NonTemporalStore = false;

using PT         = ck::tensor_operation::cpu::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::cpu::element_wise::AddReluAdd;
using AddRelu    = ck::tensor_operation::cpu::element_wise::AddRelu;
using Add        = ck::tensor_operation::cpu::element_wise::Add;
using AddAddRelu = ck::tensor_operation::cpu::element_wise::AddAddRelu;

static constexpr auto ConvFwdDefault =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Default;

static constexpr auto ConvFwd1x1P0 =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;

static constexpr auto ConvFwd1x1S1P0 =
    ck::tensor_operation::cpu::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;

static constexpr auto DefaultGemmKLoop =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::DefaultGemmKLoop;

static constexpr auto GemmKLoopOverC =
    ck::tensor_operation::cpu::device::ConvolutionForwardGemmKSpecialization_t::NHWC_GemmKLoopOverC;

static constexpr auto LoopOver_MNK = ck::tensor_operation::cpu::device::LoopOver_MNK;
static constexpr auto LoopOver_MKN = ck::tensor_operation::cpu::device::LoopOver_MKN;

void add_device_conv2d_direct_fwd_bias_activation_add_avx2_nhwc_kyxck8_nhwk(
    std::vector<DeviceConvFwdBiasActivationAddPtr<PT, PT, AddReluAdd>>& instances)
{
    ck::tensor_operation::device::instance::add_device_operation_instances(
        instances,
        std::make_tuple(
            // clang-format off
            DeviceConvNDDirectFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
                <float, float, float, float, float, PT, PT, AddReluAdd, ConvFwdDefault, 2, 6, 16,
                 false, false, false, true, true, false>
                ({0, 0, 0, DefaultGemmKLoop, LoopOver_MKN}),
            DeviceConvNDDirectFwdBiasActivationAddAvx2_Input_N_Hi_Wi_C_Weight_K_Y_X_C_K8_Output_N_Ho_Wo_K
                <float, float, float, float, float, PT, PT, AddReluAdd, ConvFwdDefault, 2, 6, 16,
                 false, false, false, true, true, false>
                ({0, 0, 0, DefaultGemmKLoop, LoopOver_MNK})
            // clang-format on
            ));
}

} // namespace device_conv2d_fwd_bias_activation_add_avx2_instance
} // namespace device
} // namespace cpu
} // namespace tensor_operation
} // namespace ck
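For reference, the Weight_K_Y_X_C_K8 name suggests the weights are packed with the output-channel dimension K split into blocks of 8, so the innermost 8 channels are contiguous for AVX2 loads. A rough sketch of the flat offset such a [K/8, Y, X, C, 8] packing would imply; the dimension order and the helper name are assumptions read off the instance name, not taken from the descriptor code itself.

#include <cstddef>
#include <cassert>

// Hypothetical flat offset for a weight tensor packed as [K/8][Y][X][C][8].
// The layout order is an assumption based on the "..._Weight_K_Y_X_C_K8_..."
// instance name; the real descriptor in the device header is the authority.
static std::size_t kyxck8_offset(std::size_t k, std::size_t y, std::size_t x, std::size_t c,
                                 std::size_t Y, std::size_t X, std::size_t C)
{
    const std::size_t k0 = k / 8; // 8-wide block index along K
    const std::size_t k1 = k % 8; // lane inside the 8-wide block
    return (((k0 * Y + y) * X + x) * C + c) * 8 + k1;
}

int main()
{
    // With Y=X=1 and C=4, element (k=9, c=2) lands in block 1, lane 1.
    assert(kyxck8_offset(9, 0, 0, 2, 1, 1, 4) == (1 * 4 + 2) * 8 + 1);
    return 0;
}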