Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
308146e7
Commit
308146e7
authored
May 03, 2022
by
Jianfeng yan
Browse files
turned on other operations
parent
8e3c41a5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
264 additions
and
283 deletions
+264
-283
include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+156
-175
profiler/CMakeLists.txt
profiler/CMakeLists.txt
+31
-31
profiler/src/profiler.cpp
profiler/src/profiler.cpp
+77
-77
No files found.
include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
308146e7
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include "batched_gemm_util.hpp"
#include "device.hpp"
#include "device.hpp"
#include "device_base.hpp"
#include "device_base.hpp"
#include "device_conv_backward_weight.hpp"
#include "device_conv_backward_weight.hpp"
...
@@ -13,7 +12,6 @@
...
@@ -13,7 +12,6 @@
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4r2.hpp"
#include "gridwise_gemm_xdlops_v2r4r2.hpp"
#include "batched_gemm_util.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -27,34 +25,33 @@ template <typename InDataType,
...
@@ -27,34 +25,33 @@ template <typename InDataType,
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
,
typename
OutElementwiseOperation
,
index_t
NumGemmKPrefetchStage
,
ck
::
index_t
BlockSize
,
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
index_t
NPerBlock
,
ck
::
index_t
K0PerBlock
,
index_t
K0PerBlock
,
ck
::
index_t
K1
,
index_t
AK1
,
ck
::
index_t
MPerXdl
,
ck
::
index_t
MPerXdl
,
ck
::
index_t
NPerXdl
,
ck
::
index_t
NPerXdl
,
ck
::
index_t
MXdlPerWave
,
ck
::
index_t
MXdlPerWave
,
ck
::
index_t
NXdlPerWave
,
ck
::
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_
A
K0_M_
A
K1
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_
A
K1
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLdsExtraM
,
bool
ABlockLds
Add
ExtraM
,
typename
BBlockTransferThreadClusterLengths_
B
K0_N_
B
K1
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_
B
K1
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsExtraN
,
bool
BBlockLds
Add
ExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
C
Shuffle
BlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
C
Shuffle
BlockTransferScalarPerVector_N
PerBlock
>
index_t
CBlockTransferScalarPerVector_N
WaveNPerXdl
>
struct
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
struct
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
:
public
DeviceConvBwdWeight
<
InElementwiseOperation
,
:
public
DeviceConvBwdWeight
<
InElementwiseOperation
,
WeiElementwiseOperation
,
WeiElementwiseOperation
,
...
@@ -95,7 +92,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -95,7 +92,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
ck
::
index_t
k_
batch
)
ck
::
index_t
batch
_k
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -120,40 +117,35 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -120,40 +117,35 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const
index_t
InRightPadH
=
input_right_pads
[
0
];
const
index_t
InRightPadH
=
input_right_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
1
];
const
index_t
InRightPadW
=
input_right_pads
[
1
];
const
index_t
GemmKTotal
=
N
*
Ho
*
Wo
;
const
index_t
GemmKTotal
=
N
*
Ho
*
Wo
;
const
index_t
GemmM
=
K
;
const
index_t
GemmM
=
K
;
const
index_t
GemmN
=
C
*
X
*
Y
;
const
index_t
GemmN
=
C
*
X
*
Y
;
const
index_t
GemmAKPerBatch
=
GemmAK0
*
GemmAK1Number
;
const
index_t
GemmBKPerBatch
=
GemmBK0
*
GemmBK1Number
;
const
index_t
GemmAK0
=
const
index_t
GemmKBatch
=
batch_k
;
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmAK1Number
*
K0PerBlock
*
k_batch
)
*
const
index_t
GemmK0
=
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmK1Number
*
K0PerBlock
*
GemmKBatch
)
*
K0PerBlock
;
K0PerBlock
;
const
index_t
GemmBK0
=
const
index_t
GemmKPad
=
GemmKBatch
*
GemmK0
*
GemmK1Number
;
math
::
integer_divide_ceil
(
GemmKTotal
,
GemmBK1Number
*
K0PerBlock
*
k_batch
)
*
K0PerBlock
;
const
index_t
GemmAKPad
=
GemmKBatch
*
GemmAK0
*
GemmAK1Number
;
const
index_t
GemmBKPad
=
GemmKBatch
*
GemmBK0
*
GemmBK1Number
;
const
auto
out_gemmk_gemmm_grid_desc
=
const
auto
out_gemmk
total
_gemmm_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
));
const
auto
in_n_hi_wi_c_grid_desc
=
const
auto
in_n_hi_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Hi
,
Wi
,
C
));
// A: output tensor
// A: output tensor
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_gemmkpad_gemmm_grid_desc
=
transform_tensor_descriptor
(
out_gemmk_gemmm_grid_desc
,
out_gemmk
total
_gemmm_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
Gemm
A
KPad
-
GemmKTotal
),
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmM
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_
gemmkbatch_
gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_gemmkpad_gemmm_grid_desc
,
out_gemmkpad_gemmm_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Gemm
A
K0
,
Gemm
A
K1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Gemm
KBatch
,
Gemm
K0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmM
)),
make_pass_through_transform
(
GemmM
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// B: input tensor
// B: input tensor
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
...
@@ -184,24 +176,24 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -184,24 +176,24 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmkpad_gemmn_grid_desc
=
transform_tensor_descriptor
(
in_gemmktotal_gemmn_grid_desc
,
in_gemmktotal_gemmn_grid_desc
,
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
Gemm
B
KPad
-
GemmKTotal
),
make_tuple
(
make_right_pad_transform
(
GemmKTotal
,
GemmKPad
-
GemmKTotal
),
make_pass_through_transform
(
GemmN
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_
gemmkbatch_
gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
in_gemmkpad_gemmn_grid_desc
,
in_gemmkpad_gemmn_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Gemm
B
K0
,
Gemm
B
K1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Gemm
KBatch
,
Gemm
K0
,
GemmK1Number
)),
make_pass_through_transform
(
GemmN
)),
make_pass_through_transform
(
GemmN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
// C: weight tensor
// C: weight tensor
const
auto
wei_gemmm_gemmn_grid_desc
=
const
auto
wei_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
*
X
*
C
));
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
*
X
*
C
));
return
make_tuple
(
out_gemmk0_gemmm_gemmk1_grid_desc
,
return
make_tuple
(
out_
gemmkbatch_
gemmk0_gemmm_gemmk1_grid_desc
,
in_gemmk0_gemmn_gemmk1_grid_desc
,
in_
gemmkbatch_
gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
);
wei_gemmm_gemmn_grid_desc
);
}
}
...
@@ -213,97 +205,93 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -213,97 +205,93 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
BlockSize
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
GemmAccDataType
,
AccDataType
,
CShuffleDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
KPerBlock
,
K0PerBlock
,
AK1
,
MPerXdl
,
BK1
,
NPerXdl
,
MPerXDL
,
K1
,
NPerXDL
,
MXdlPerWave
,
MXdlPerWave
,
NXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_
A
K0_M_
A
K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_
A
K1
,
ABlockTransferDstScalarPerVector_K1
,
false
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsExtraM
,
ABlockLds
Add
ExtraM
,
BBlockTransferThreadClusterLengths_
B
K0_N_
B
K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_
B
K1
,
BBlockTransferDstScalarPerVector_K1
,
false
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsExtraN
,
BBlockLds
Add
ExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CBlockTransferScalarPerVector_NWaveNPerXdl
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
;
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
;
CShuffleBlockTransferScalarPerVector_NPerBlock
>
;
using
GridwiseGemmAtomicAdd
=
GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
using
GridwiseGemmAtomicAdd
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
BlockSize
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
GemmAccDataType
,
AccDataType
,
CShuffleDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
::
AtomicAdd
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
InMemoryDataOperationEnum
::
AtomicAdd
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
KPerBlock
,
K0PerBlock
,
AK1
,
MPerXdl
,
BK1
,
NPerXdl
,
MPerXDL
,
K1
,
NPerXDL
,
MXdlPerWave
,
MXdlPerWave
,
NXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_
A
K0_M_
A
K1
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_
A
K1
,
ABlockTransferDstScalarPerVector_K1
,
false
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsExtraM
,
ABlockLds
Add
ExtraM
,
BBlockTransferThreadClusterLengths_
B
K0_N_
B
K1
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_
B
K1
,
BBlockTransferDstScalarPerVector_K1
,
false
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsExtraN
,
BBlockLds
Add
ExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CBlockTransferScalarPerVector_NWaveNPerXdl
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
;
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
;
CShuffleBlockTransferScalarPerVector_NPerBlock
>
;
// Argument
using
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}));
using
Block2CTileMap
=
using
Block2CTileMap
=
decltype
(
BatchedGemmUtil
::
MakeBlock2CTileMap
<
MPerBlock
,
NPerBlock
>
(
1
,
1
,
1
));
decltype
(
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
CGridDesc_M_N
{},
1
,
1
,
1
));
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
Argument
(
const
InDataType
*
p_in_grid
,
Argument
(
const
InDataType
*
p_in_grid
,
...
@@ -328,8 +316,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -328,8 +316,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
:
p_a_grid_
{
p_out_grid
},
:
p_a_grid_
{
p_out_grid
},
p_b_grid_
{
p_in_grid
},
p_b_grid_
{
p_in_grid
},
p_c_grid_
{
p_wei_grid
},
p_c_grid_
{
p_wei_grid
},
a_grid_desc_k0_m_k1_
{},
a_grid_desc_
kbatch_
k0_m_k1_
{},
b_grid_desc_k0_n_k1_
{},
b_grid_desc_
kbatch_
k0_n_k1_
{},
c_grid_desc_m_n_
{},
c_grid_desc_m_n_
{},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_ctile_map_
{},
block_2_ctile_map_
{},
...
@@ -361,36 +349,31 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -361,36 +349,31 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
input_right_pads
,
input_right_pads
,
k_batch_
);
k_batch_
);
a_grid_desc_k0_m_k1_
=
descs
[
I0
];
a_grid_desc_
kbatch_
k0_m_k1_
=
descs
[
I0
];
b_grid_desc_k0_n_k1_
=
descs
[
I1
];
b_grid_desc_
kbatch_
k0_n_k1_
=
descs
[
I1
];
c_grid_desc_m_n_
=
descs
[
I2
];
c_grid_desc_m_n_
=
descs
[
I2
];
if
(
GridwiseGemm
::
CheckValidity
(
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_kbatch_k0_m_k1_
,
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
b_grid_desc_kbatch_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
{
{
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
c_grid_desc_m_n_
);
block_2_ctile_map_
=
block_2_ctile_map_
=
BatchedGemmUtil
::
MakeBlock2CTileMap
<
MPerBlock
,
NPerBlock
>
(
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_grid_desc_m_n_
,
M01
,
N01
,
k_batch_
);
k_batch_
,
c_grid_desc_m_n_
.
GetLength
(
I0
),
c_grid_desc_m_n_
.
GetLength
(
I1
),
M01_
,
N01_
);
}
}
}
}
const
ADataType
*
p_a_grid_
;
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
CDataType
*
p_c_grid_
;
index_t
k_batch_
;
AGridDesc_K0_M_K1
a_grid_desc_kbatch_k0_m_k1_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_kbatch_k0_n_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
BatchedGemmUtil
::
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
Block2CTileMap
block_2_ctile_map_
;
Block2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
M01_
;
index_t
N01_
;
index_t
N01_
;
...
@@ -416,14 +399,17 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -416,14 +399,17 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
void
ShowInfo
(
const
Argument
&
arg
)
void
ShowInfo
(
const
Argument
&
arg
)
{
{
std
::
cout
<<
"k_batch = "
<<
arg
.
BatchCount_
<<
"
\n
"
;
std
::
cout
<<
"arg.a_grid_desc_kbatch_k0_m_k1_{"
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_k0_n_k1_{"
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
std
::
cout
<<
"arg.b_grid_desc_kbatch_k0_n_k1_{"
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
...
@@ -432,8 +418,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -432,8 +418,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
{
ShowInfo
(
arg
);
ShowInfo
(
arg
);
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
kbatch_
k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_
kbatch_
k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
M01_
,
arg
.
N01_
))
arg
.
N01_
))
...
@@ -441,11 +427,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -441,11 +427,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"
);
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"
);
}
}
const
auto
k_batch
=
arg
.
k_batch_
;
const
auto
kbatch
=
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
);
const
index_t
grid_size
=
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
,
kbatch
);
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
k_batch
;
const
auto
K0
=
arg
.
a_grid_desc_
a
k0_m_
a
k1_
.
GetLength
(
I
0
);
const
auto
K0
=
arg
.
a_grid_desc_
kbatch_
k0_m_k1_
.
GetLength
(
I
1
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
...
@@ -463,8 +448,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -463,8 +448,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
a_grid_desc_
kbatch_
k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_
kbatch_
k0_n_k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
...
@@ -472,7 +457,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -472,7 +457,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
}
}
if
(
k
_
batch
>
1
||
nrepeat
<=
0
)
if
(
kbatch
>
1
||
nrepeat
<=
0
)
{
{
hipGetErrorString
(
hipMemset
(
hipGetErrorString
(
hipMemset
(
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
...
@@ -487,8 +472,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -487,8 +472,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
a_grid_desc_
kbatch_
k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_
kbatch_
k0_n_k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
...
@@ -499,38 +484,36 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -499,38 +484,36 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
if
(
has_main_k0_block_loop
)
if
(
has_main_k0_block_loop
)
{
{
if
(
k
_
batch
==
1
)
if
(
kbatch
==
1
)
{
{
const
auto
kernel
=
kernel_
batched_gemm_xdl_cshuffle_v1
<
const
auto
kernel
=
kernel_
gemm_xdlops_v2r4r2
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
CDataType
,
AElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
BElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
CElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
OutElementwiseOperation
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
InElementwiseOperation
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
WeiElementwiseOperation
,
ComputePtrOffsetOfStridedBatch
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
Block2CTileMap
,
true
>
;
true
>
;
Run
(
kernel
);
Run
(
kernel
);
}
}
else
else
{
{
const
auto
kernel
=
kernel_
batched_gemm_xdl_cshuffle_v1
<
const
auto
kernel
=
kernel_
gemm_xdlops_v2r4r2
<
GridwiseGemmAtomicAdd
,
GridwiseGemmAtomicAdd
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
CDataType
,
AElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
BElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
CElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
OutElementwiseOperation
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
InElementwiseOperation
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
WeiElementwiseOperation
,
ComputePtrOffsetOfStridedBatch
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
Block2CTileMap
,
true
>
;
true
>
;
Run
(
kernel
);
Run
(
kernel
);
...
@@ -538,38 +521,36 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -538,38 +521,36 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
}
}
else
else
{
{
if
(
k
_
batch
==
1
)
if
(
kbatch
==
1
)
{
{
const
auto
kernel
=
kernel_
batched_gemm_xdl_cshuffle_v1
<
const
auto
kernel
=
kernel_
gemm_xdlops_v2r4r2
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
CDataType
,
AElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
BElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
CElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
OutElementwiseOperation
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
InElementwiseOperation
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
WeiElementwiseOperation
,
ComputePtrOffsetOfStridedBatch
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
Block2CTileMap
,
false
>
;
false
>
;
Run
(
kernel
);
Run
(
kernel
);
}
}
else
else
{
{
const
auto
kernel
=
kernel_
batched_gemm_xdl_cshuffle_v1
<
const
auto
kernel
=
kernel_
gemm_xdlops_v2r4r2
<
GridwiseGemmAtomicAdd
,
GridwiseGemmAtomicAdd
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
CDataType
,
AElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
BElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
CElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
OutElementwiseOperation
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
InElementwiseOperation
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
WeiElementwiseOperation
,
ComputePtrOffsetOfStridedBatch
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
Block2CTileMap
,
false
>
;
false
>
;
Run
(
kernel
);
Run
(
kernel
);
...
@@ -602,14 +583,14 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
...
@@ -602,14 +583,14 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
}
}
// vector store C matrix into global memory
// vector store C matrix into global memory
if
(
!
(
arg
.
Conv_C_
%
C
Shuffle
BlockTransferScalarPerVector_N
PerBlock
==
0
))
if
(
!
(
arg
.
Conv_C_
%
CBlockTransferScalarPerVector_N
WaveNPerXdl
==
0
))
{
{
return
false
;
return
false
;
}
}
// Gridwise GEMM size
// Gridwise GEMM size
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
kbatch_
k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_
kbatch_
k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
M01_
,
arg
.
N01_
);
arg
.
N01_
);
...
...
profiler/CMakeLists.txt
View file @
308146e7
...
@@ -24,40 +24,40 @@ include_directories(BEFORE
...
@@ -24,40 +24,40 @@ include_directories(BEFORE
set
(
PROFILER_SOURCE
set
(
PROFILER_SOURCE
src/profiler.cpp
src/profiler.cpp
src/profile_gemm.cpp
src/profile_gemm.cpp
#
src/profile_gemm_bias_2d.cpp
src/profile_gemm_bias_2d.cpp
#
src/profile_gemm_bias_relu.cpp
src/profile_gemm_bias_relu.cpp
#
src/profile_gemm_bias_relu_add.cpp
src/profile_gemm_bias_relu_add.cpp
#
src/profile_gemm_reduce.cpp
src/profile_gemm_reduce.cpp
#
src/profile_batched_gemm.cpp
src/profile_batched_gemm.cpp
#
src/profile_conv_fwd_bias_relu.cpp
src/profile_conv_fwd_bias_relu.cpp
#
src/profile_conv_fwd_bias_relu_add.cpp
src/profile_conv_fwd_bias_relu_add.cpp
#
src/profile_conv_fwd_bias_relu_atomic_add.cpp
src/profile_conv_fwd_bias_relu_atomic_add.cpp
#
src/profile_convnd_fwd.cpp
src/profile_convnd_fwd.cpp
#
src/profile_convnd_bwd_data.cpp
src/profile_convnd_bwd_data.cpp
#
src/profile_reduce.cpp
src/profile_reduce.cpp
#
src/profile_grouped_gemm.cpp
src/profile_grouped_gemm.cpp
#
src/profile_conv_bwd_weight.cpp
src/profile_conv_bwd_weight.cpp
#
src/profile_batched_gemm_reduce.cpp
src/profile_batched_gemm_reduce.cpp
)
)
add_executable
(
ckProfiler
${
PROFILER_SOURCE
}
)
add_executable
(
ckProfiler
${
PROFILER_SOURCE
}
)
target_link_libraries
(
ckProfiler PRIVATE host_tensor
)
target_link_libraries
(
ckProfiler PRIVATE host_tensor
)
#
target_link_libraries(ckProfiler PRIVATE conv_fwd_util)
target_link_libraries
(
ckProfiler PRIVATE conv_fwd_util
)
#
target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_reduce_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_instance
)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_bias2d_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_bias_relu_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
target_link_libraries
(
ckProfiler PRIVATE device_gemm_bias_relu_add_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
target_link_libraries
(
ckProfiler PRIVATE device_batched_gemm_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance)
target_link_libraries
(
ckProfiler PRIVATE device_conv1d_fwd_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_fwd_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance)
target_link_libraries
(
ckProfiler PRIVATE device_conv3d_fwd_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
target_link_libraries
(
ckProfiler PRIVATE device_convnd_bwd_data_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
target_link_libraries
(
ckProfiler PRIVATE device_reduce_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
target_link_libraries
(
ckProfiler PRIVATE device_grouped_gemm_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
target_link_libraries
(
ckProfiler PRIVATE device_conv2d_bwd_weight_instance
)
#
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries
(
ckProfiler PRIVATE device_batched_gemm_reduce_instance
)
profiler/src/profiler.cpp
View file @
308146e7
...
@@ -7,19 +7,19 @@
...
@@ -7,19 +7,19 @@
#include "profile_convnd_fwd.hpp"
#include "profile_convnd_fwd.hpp"
int
profile_gemm
(
int
,
char
*
[]);
int
profile_gemm
(
int
,
char
*
[]);
//
int profile_gemm_bias_2d(int, char*[]);
int
profile_gemm_bias_2d
(
int
,
char
*
[]);
//
int profile_gemm_bias_relu(int, char*[]);
int
profile_gemm_bias_relu
(
int
,
char
*
[]);
//
int profile_gemm_bias_relu_add(int, char*[]);
int
profile_gemm_bias_relu_add
(
int
,
char
*
[]);
//
int profile_gemm_reduce(int, char*[]);
int
profile_gemm_reduce
(
int
,
char
*
[]);
//
int profile_batched_gemm(int, char*[]);
int
profile_batched_gemm
(
int
,
char
*
[]);
//
int profile_grouped_gemm(int, char*[]);
int
profile_grouped_gemm
(
int
,
char
*
[]);
//
int profile_conv_fwd_bias_relu(int, char*[]);
int
profile_conv_fwd_bias_relu
(
int
,
char
*
[]);
//
int profile_conv_fwd_bias_relu_add(int, char*[]);
int
profile_conv_fwd_bias_relu_add
(
int
,
char
*
[]);
//
int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);
int
profile_conv_fwd_bias_relu_atomic_add
(
int
,
char
*
[]);
//
int profile_convnd_bwd_data(int, char*[], int);
int
profile_convnd_bwd_data
(
int
,
char
*
[],
int
);
//
int profile_reduce(int, char*[]);
int
profile_reduce
(
int
,
char
*
[]);
//
int profile_conv_bwd_weight(int, char*[]);
int
profile_conv_bwd_weight
(
int
,
char
*
[]);
//
int profile_batched_gemm_reduce(int, char*[]);
int
profile_batched_gemm_reduce
(
int
,
char
*
[]);
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
...
@@ -27,70 +27,70 @@ int main(int argc, char* argv[])
...
@@ -27,70 +27,70 @@ int main(int argc, char* argv[])
{
{
return
profile_gemm
(
argc
,
argv
);
return
profile_gemm
(
argc
,
argv
);
}
}
//
else if(strcmp(argv[1], "gemm_bias_2d") == 0)
else
if
(
strcmp
(
argv
[
1
],
"gemm_bias_2d"
)
==
0
)
//
{
{
//
return profile_gemm_bias_2d(argc, argv);
return
profile_gemm_bias_2d
(
argc
,
argv
);
//
}
}
//
else if(strcmp(argv[1], "gemm_bias_relu") == 0)
else
if
(
strcmp
(
argv
[
1
],
"gemm_bias_relu"
)
==
0
)
//
{
{
//
return profile_gemm_bias_relu(argc, argv);
return
profile_gemm_bias_relu
(
argc
,
argv
);
//
}
}
//
else if(strcmp(argv[1], "gemm_bias_relu_add") == 0)
else
if
(
strcmp
(
argv
[
1
],
"gemm_bias_relu_add"
)
==
0
)
//
{
{
//
return profile_gemm_bias_relu_add(argc, argv);
return
profile_gemm_bias_relu_add
(
argc
,
argv
);
//
}
}
//
else if(strcmp(argv[1], "gemm_reduce") == 0)
else
if
(
strcmp
(
argv
[
1
],
"gemm_reduce"
)
==
0
)
//
{
{
//
return profile_gemm_reduce(argc, argv);
return
profile_gemm_reduce
(
argc
,
argv
);
//
}
}
//
else if(strcmp(argv[1], "batched_gemm") == 0)
else
if
(
strcmp
(
argv
[
1
],
"batched_gemm"
)
==
0
)
//
{
{
//
return profile_batched_gemm(argc, argv);
return
profile_batched_gemm
(
argc
,
argv
);
//
}
}
//
else if(strcmp(argv[1], "batched_gemm_reduce") == 0)
else
if
(
strcmp
(
argv
[
1
],
"batched_gemm_reduce"
)
==
0
)
//
{
{
//
return profile_batched_gemm_reduce(argc, argv);
return
profile_batched_gemm_reduce
(
argc
,
argv
);
//
}
}
//
else if(strcmp(argv[1], "grouped_gemm") == 0)
else
if
(
strcmp
(
argv
[
1
],
"grouped_gemm"
)
==
0
)
//
{
{
//
profile_grouped_gemm(argc, argv);
profile_grouped_gemm
(
argc
,
argv
);
//
}
}
//
else if(strcmp(argv[1], "conv_fwd") == 0)
else
if
(
strcmp
(
argv
[
1
],
"conv_fwd"
)
==
0
)
//
{
{
//
return ck::profiler::profile_convnd_fwd(argc, argv);
return
ck
::
profiler
::
profile_convnd_fwd
(
argc
,
argv
);
//
}
}
//
else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
else
if
(
strcmp
(
argv
[
1
],
"conv_fwd_bias_relu"
)
==
0
)
//
{
{
//
return profile_conv_fwd_bias_relu(argc, argv);
return
profile_conv_fwd_bias_relu
(
argc
,
argv
);
//
}
}
//
else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0)
else
if
(
strcmp
(
argv
[
1
],
"conv_fwd_bias_relu_add"
)
==
0
)
//
{
{
//
return profile_conv_fwd_bias_relu_add(argc, argv);
return
profile_conv_fwd_bias_relu_add
(
argc
,
argv
);
//
}
}
//
else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0)
else
if
(
strcmp
(
argv
[
1
],
"conv_fwd_bias_relu_atomic_add"
)
==
0
)
//
{
{
//
return profile_conv_fwd_bias_relu_atomic_add(argc, argv);
return
profile_conv_fwd_bias_relu_atomic_add
(
argc
,
argv
);
//
}
}
//
else if(strcmp(argv[1], "conv1d_bwd_data") == 0)
else
if
(
strcmp
(
argv
[
1
],
"conv1d_bwd_data"
)
==
0
)
//
{
{
//
return profile_convnd_bwd_data(argc, argv, 1);
return
profile_convnd_bwd_data
(
argc
,
argv
,
1
);
//
}
}
//
else if(strcmp(argv[1], "conv2d_bwd_data") == 0)
else
if
(
strcmp
(
argv
[
1
],
"conv2d_bwd_data"
)
==
0
)
//
{
{
//
return profile_convnd_bwd_data(argc, argv, 2);
return
profile_convnd_bwd_data
(
argc
,
argv
,
2
);
//
}
}
//
else if(strcmp(argv[1], "conv3d_bwd_data") == 0)
else
if
(
strcmp
(
argv
[
1
],
"conv3d_bwd_data"
)
==
0
)
//
{
{
//
return profile_convnd_bwd_data(argc, argv, 3);
return
profile_convnd_bwd_data
(
argc
,
argv
,
3
);
//
}
}
//
else if(strcmp(argv[1], "reduce") == 0)
else
if
(
strcmp
(
argv
[
1
],
"reduce"
)
==
0
)
//
{
{
//
return profile_reduce(argc, argv);
return
profile_reduce
(
argc
,
argv
);
//
}
}
//
else if(strcmp(argv[1], "conv2d_bwd_weight") == 0)
else
if
(
strcmp
(
argv
[
1
],
"conv2d_bwd_weight"
)
==
0
)
//
{
{
//
return profile_conv_bwd_weight(argc, argv);
return
profile_conv_bwd_weight
(
argc
,
argv
);
//
}
}
else
else
{
{
// clang-format off
// clang-format off
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment