Commit 308146e7 authored by Jianfeng yan's avatar Jianfeng yan
Browse files

turned on other operations

parent 8e3c41a5
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "batched_gemm_util.hpp"
#include "device.hpp" #include "device.hpp"
#include "device_base.hpp" #include "device_base.hpp"
#include "device_conv_backward_weight.hpp" #include "device_conv_backward_weight.hpp"
...@@ -13,7 +12,6 @@ ...@@ -13,7 +12,6 @@
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4r2.hpp" #include "gridwise_gemm_xdlops_v2r4r2.hpp"
#include "batched_gemm_util.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -27,34 +25,33 @@ template <typename InDataType, ...@@ -27,34 +25,33 @@ template <typename InDataType,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
index_t NumGemmKPrefetchStage, ck::index_t BlockSize,
index_t BlockSize, ck::index_t MPerBlock,
index_t MPerBlock, ck::index_t NPerBlock,
index_t NPerBlock, ck::index_t K0PerBlock,
index_t K0PerBlock, ck::index_t K1,
index_t AK1,
ck::index_t MPerXdl, ck::index_t MPerXdl,
ck::index_t NPerXdl, ck::index_t NPerXdl,
ck::index_t MXdlPerWave, ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave, ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1, typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim, ck::index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector, ck::index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1, ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsExtraM, bool ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1, typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder, typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder, typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim, ck::index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector, ck::index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1, ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsExtraN, bool BBlockLdsAddExtraN,
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock> index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvBwdWeight<InElementwiseOperation, : public DeviceConvBwdWeight<InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
...@@ -95,7 +92,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -95,7 +92,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
ck::index_t k_batch) ck::index_t batch_k)
{ {
using namespace ck; using namespace ck;
...@@ -120,40 +117,35 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -120,40 +117,35 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const index_t InRightPadH = input_right_pads[0]; const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1]; const index_t InRightPadW = input_right_pads[1];
const index_t GemmKTotal = N * Ho * Wo; const index_t GemmKTotal = N * Ho * Wo;
const index_t GemmM = K; const index_t GemmM = K;
const index_t GemmN = C * X * Y; const index_t GemmN = C * X * Y;
const index_t GemmAKPerBatch = GemmAK0 * GemmAK1Number;
const index_t GemmBKPerBatch = GemmBK0 * GemmBK1Number;
const index_t GemmAK0 = const index_t GemmKBatch = batch_k;
math::integer_divide_ceil(GemmKTotal, GemmAK1Number * K0PerBlock * k_batch) * const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock; K0PerBlock;
const index_t GemmBK0 = const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
math::integer_divide_ceil(GemmKTotal, GemmBK1Number * K0PerBlock * k_batch) *
K0PerBlock;
const index_t GemmAKPad = GemmKBatch * GemmAK0 * GemmAK1Number;
const index_t GemmBKPad = GemmKBatch * GemmBK0 * GemmBK1Number;
const auto out_gemmk_gemmm_grid_desc = const auto out_gemmktotal_gemmm_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
const auto in_n_hi_wi_c_grid_desc = const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
// A: output tensor // A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_grid_desc, out_gemmktotal_gemmm_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmAKPad - GemmKTotal), make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmM)), make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc, out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmAK0, GemmAK1Number)), make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)), make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: input tensor // B: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
...@@ -184,24 +176,24 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -184,24 +176,24 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc, in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmBKPad - GemmKTotal), make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)), make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc, in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmBK0, GemmBK1Number)), make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)), make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor // C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = const auto wei_gemmm_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc); wei_gemmm_gemmn_grid_desc);
} }
...@@ -213,97 +205,93 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -213,97 +205,93 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>; using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, AccDataType,
CShuffleDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, K0PerBlock,
AK1, MPerXdl,
BK1, NPerXdl,
MPerXDL, K1,
NPerXDL,
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1, ABlockTransferDstScalarPerVector_K1,
false, false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsExtraM, ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1, BBlockTransferDstScalarPerVector_K1,
false, false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsExtraN, BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle, CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
CShuffleBlockTransferScalarPerVector_NPerBlock > ;
using GridwiseGemmAtomicAdd = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, AccDataType,
CShuffleDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
InMemoryDataOperationEnum::AtomicAdd,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, K0PerBlock,
AK1, MPerXdl,
BK1, NPerXdl,
MPerXDL, K1,
NPerXDL,
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1, ABlockTransferDstScalarPerVector_K1,
false, false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsExtraM, ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1, BBlockTransferDstScalarPerVector_K1,
false, false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsExtraN, BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle, CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
CShuffleBlockTransferScalarPerVector_NPerBlock > ; // Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
using Block2CTileMap = using Block2CTileMap =
decltype(BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(1, 1, 1)); decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const InDataType* p_in_grid, Argument(const InDataType* p_in_grid,
...@@ -328,8 +316,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -328,8 +316,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
: p_a_grid_{p_out_grid}, : p_a_grid_{p_out_grid},
p_b_grid_{p_in_grid}, p_b_grid_{p_in_grid},
p_c_grid_{p_wei_grid}, p_c_grid_{p_wei_grid},
a_grid_desc_k0_m_k1_{}, a_grid_desc_kbatch_k0_m_k1_{},
b_grid_desc_k0_n_k1_{}, b_grid_desc_kbatch_k0_n_k1_{},
c_grid_desc_m_n_{}, c_grid_desc_m_n_{},
c_grid_desc_mblock_mperblock_nblock_nperblock_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{}, block_2_ctile_map_{},
...@@ -361,36 +349,31 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -361,36 +349,31 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
input_right_pads, input_right_pads,
k_batch_); k_batch_);
a_grid_desc_k0_m_k1_ = descs[I0]; a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1]; b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2]; c_grid_desc_m_n_ = descs[I2];
if(GridwiseGemm::CheckValidity( if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_m_n_,
M01_,
N01_))
{ {
c_grid_desc_mblock_mperblock_nblock_nperblock_ = c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
c_grid_desc_m_n_);
block_2_ctile_map_ =
block_2_ctile_map_ = BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>( GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
k_batch_,
c_grid_desc_m_n_.GetLength(I0),
c_grid_desc_m_n_.GetLength(I1),
M01_,
N01_);
} }
} }
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
index_t k_batch_; AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
c_grid_desc_mblock_mperblock_nblock_nperblock_;
BatchedGemmUtil::ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
Block2CTileMap block_2_ctile_map_; Block2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
...@@ -416,14 +399,17 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -416,14 +399,17 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
void ShowInfo(const Argument& arg) void ShowInfo(const Argument& arg)
{ {
std::cout << "k_batch = " << arg.BatchCount_ << "\n"; std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
...@@ -432,8 +418,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -432,8 +418,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, int nrepeat = 1)
{ {
ShowInfo(arg); ShowInfo(arg);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
arg.M01_, arg.M01_,
arg.N01_)) arg.N01_))
...@@ -441,11 +427,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -441,11 +427,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
} }
const auto k_batch = arg.k_batch_; const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
const index_t grid_size = const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch);
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * k_batch;
const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0); const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
...@@ -463,8 +448,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -463,8 +448,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
...@@ -472,7 +457,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -472,7 +457,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
if(k_batch > 1 || nrepeat <= 0) if(kbatch > 1 || nrepeat <= 0)
{ {
hipGetErrorString(hipMemset( hipGetErrorString(hipMemset(
arg.p_c_grid_, arg.p_c_grid_,
...@@ -487,8 +472,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -487,8 +472,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_, arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
...@@ -499,38 +484,36 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -499,38 +484,36 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
{ {
if(k_batch == 1) if(kbatch == 1)
{ {
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1< const auto kernel = kernel_gemm_xdlops_v2r4r2<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
AElementwiseOperation, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
BElementwiseOperation, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
CElementwiseOperation, remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
DeviceOp::AGridDesc_AK0_M_AK1, OutElementwiseOperation,
DeviceOp::BGridDesc_BK0_N_BK1, InElementwiseOperation,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, WeiElementwiseOperation,
ComputePtrOffsetOfStridedBatch, remove_reference_t<DeviceOp::Block2CTileMap>,
Block2CTileMap,
true>; true>;
Run(kernel); Run(kernel);
} }
else else
{ {
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1< const auto kernel = kernel_gemm_xdlops_v2r4r2<
GridwiseGemmAtomicAdd, GridwiseGemmAtomicAdd,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
AElementwiseOperation, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
BElementwiseOperation, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
CElementwiseOperation, remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
DeviceOp::AGridDesc_AK0_M_AK1, OutElementwiseOperation,
DeviceOp::BGridDesc_BK0_N_BK1, InElementwiseOperation,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, WeiElementwiseOperation,
ComputePtrOffsetOfStridedBatch, remove_reference_t<DeviceOp::Block2CTileMap>,
Block2CTileMap,
true>; true>;
Run(kernel); Run(kernel);
...@@ -538,38 +521,36 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -538,38 +521,36 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
} }
else else
{ {
if(k_batch == 1) if(kbatch == 1)
{ {
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1< const auto kernel = kernel_gemm_xdlops_v2r4r2<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
AElementwiseOperation, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
BElementwiseOperation, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
CElementwiseOperation, remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
DeviceOp::AGridDesc_AK0_M_AK1, OutElementwiseOperation,
DeviceOp::BGridDesc_BK0_N_BK1, InElementwiseOperation,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, WeiElementwiseOperation,
ComputePtrOffsetOfStridedBatch, remove_reference_t<DeviceOp::Block2CTileMap>,
Block2CTileMap,
false>; false>;
Run(kernel); Run(kernel);
} }
else else
{ {
const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1< const auto kernel = kernel_gemm_xdlops_v2r4r2<
GridwiseGemmAtomicAdd, GridwiseGemmAtomicAdd,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
AElementwiseOperation, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
BElementwiseOperation, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
CElementwiseOperation, remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
DeviceOp::AGridDesc_AK0_M_AK1, OutElementwiseOperation,
DeviceOp::BGridDesc_BK0_N_BK1, InElementwiseOperation,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, WeiElementwiseOperation,
ComputePtrOffsetOfStridedBatch, remove_reference_t<DeviceOp::Block2CTileMap>,
Block2CTileMap,
false>; false>;
Run(kernel); Run(kernel);
...@@ -602,14 +583,14 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -602,14 +583,14 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
} }
// vector store C matrix into global memory // vector store C matrix into global memory
if(!(arg.Conv_C_ % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
{ {
return false; return false;
} }
// Gridwise GEMM size // Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_, arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
arg.M01_, arg.M01_,
arg.N01_); arg.N01_);
......
...@@ -24,40 +24,40 @@ include_directories(BEFORE ...@@ -24,40 +24,40 @@ include_directories(BEFORE
set(PROFILER_SOURCE set(PROFILER_SOURCE
src/profiler.cpp src/profiler.cpp
src/profile_gemm.cpp src/profile_gemm.cpp
# src/profile_gemm_bias_2d.cpp src/profile_gemm_bias_2d.cpp
# src/profile_gemm_bias_relu.cpp src/profile_gemm_bias_relu.cpp
# src/profile_gemm_bias_relu_add.cpp src/profile_gemm_bias_relu_add.cpp
# src/profile_gemm_reduce.cpp src/profile_gemm_reduce.cpp
# src/profile_batched_gemm.cpp src/profile_batched_gemm.cpp
# src/profile_conv_fwd_bias_relu.cpp src/profile_conv_fwd_bias_relu.cpp
# src/profile_conv_fwd_bias_relu_add.cpp src/profile_conv_fwd_bias_relu_add.cpp
# src/profile_conv_fwd_bias_relu_atomic_add.cpp src/profile_conv_fwd_bias_relu_atomic_add.cpp
# src/profile_convnd_fwd.cpp src/profile_convnd_fwd.cpp
# src/profile_convnd_bwd_data.cpp src/profile_convnd_bwd_data.cpp
# src/profile_reduce.cpp src/profile_reduce.cpp
# src/profile_grouped_gemm.cpp src/profile_grouped_gemm.cpp
# src/profile_conv_bwd_weight.cpp src/profile_conv_bwd_weight.cpp
# src/profile_batched_gemm_reduce.cpp src/profile_batched_gemm_reduce.cpp
) )
add_executable(ckProfiler ${PROFILER_SOURCE}) add_executable(ckProfiler ${PROFILER_SOURCE})
target_link_libraries(ckProfiler PRIVATE host_tensor) target_link_libraries(ckProfiler PRIVATE host_tensor)
# target_link_libraries(ckProfiler PRIVATE conv_fwd_util) target_link_libraries(ckProfiler PRIVATE conv_fwd_util)
# target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
# target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
# target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
# target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
# target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance) target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
# target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance) target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance)
# target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
# target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance) target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance)
# target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
# target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
# target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance)
# target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance) target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
# target_link_libraries(ckProfiler PRIVATE device_reduce_instance) target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
# target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance) target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
# target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
# target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance) target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
...@@ -7,19 +7,19 @@ ...@@ -7,19 +7,19 @@
#include "profile_convnd_fwd.hpp" #include "profile_convnd_fwd.hpp"
int profile_gemm(int, char*[]); int profile_gemm(int, char*[]);
// int profile_gemm_bias_2d(int, char*[]); int profile_gemm_bias_2d(int, char*[]);
// int profile_gemm_bias_relu(int, char*[]); int profile_gemm_bias_relu(int, char*[]);
// int profile_gemm_bias_relu_add(int, char*[]); int profile_gemm_bias_relu_add(int, char*[]);
// int profile_gemm_reduce(int, char*[]); int profile_gemm_reduce(int, char*[]);
// int profile_batched_gemm(int, char*[]); int profile_batched_gemm(int, char*[]);
// int profile_grouped_gemm(int, char*[]); int profile_grouped_gemm(int, char*[]);
// int profile_conv_fwd_bias_relu(int, char*[]); int profile_conv_fwd_bias_relu(int, char*[]);
// int profile_conv_fwd_bias_relu_add(int, char*[]); int profile_conv_fwd_bias_relu_add(int, char*[]);
// int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);
// int profile_convnd_bwd_data(int, char*[], int); int profile_convnd_bwd_data(int, char*[], int);
// int profile_reduce(int, char*[]); int profile_reduce(int, char*[]);
// int profile_conv_bwd_weight(int, char*[]); int profile_conv_bwd_weight(int, char*[]);
// int profile_batched_gemm_reduce(int, char*[]); int profile_batched_gemm_reduce(int, char*[]);
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
...@@ -27,70 +27,70 @@ int main(int argc, char* argv[]) ...@@ -27,70 +27,70 @@ int main(int argc, char* argv[])
{ {
return profile_gemm(argc, argv); return profile_gemm(argc, argv);
} }
// else if(strcmp(argv[1], "gemm_bias_2d") == 0) else if(strcmp(argv[1], "gemm_bias_2d") == 0)
// { {
// return profile_gemm_bias_2d(argc, argv); return profile_gemm_bias_2d(argc, argv);
// } }
// else if(strcmp(argv[1], "gemm_bias_relu") == 0) else if(strcmp(argv[1], "gemm_bias_relu") == 0)
// { {
// return profile_gemm_bias_relu(argc, argv); return profile_gemm_bias_relu(argc, argv);
// } }
// else if(strcmp(argv[1], "gemm_bias_relu_add") == 0) else if(strcmp(argv[1], "gemm_bias_relu_add") == 0)
// { {
// return profile_gemm_bias_relu_add(argc, argv); return profile_gemm_bias_relu_add(argc, argv);
// } }
// else if(strcmp(argv[1], "gemm_reduce") == 0) else if(strcmp(argv[1], "gemm_reduce") == 0)
// { {
// return profile_gemm_reduce(argc, argv); return profile_gemm_reduce(argc, argv);
// } }
// else if(strcmp(argv[1], "batched_gemm") == 0) else if(strcmp(argv[1], "batched_gemm") == 0)
// { {
// return profile_batched_gemm(argc, argv); return profile_batched_gemm(argc, argv);
// } }
// else if(strcmp(argv[1], "batched_gemm_reduce") == 0) else if(strcmp(argv[1], "batched_gemm_reduce") == 0)
// { {
// return profile_batched_gemm_reduce(argc, argv); return profile_batched_gemm_reduce(argc, argv);
// } }
// else if(strcmp(argv[1], "grouped_gemm") == 0) else if(strcmp(argv[1], "grouped_gemm") == 0)
// { {
// profile_grouped_gemm(argc, argv); profile_grouped_gemm(argc, argv);
// } }
// else if(strcmp(argv[1], "conv_fwd") == 0) else if(strcmp(argv[1], "conv_fwd") == 0)
// { {
// return ck::profiler::profile_convnd_fwd(argc, argv); return ck::profiler::profile_convnd_fwd(argc, argv);
// } }
// else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0) else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
// { {
// return profile_conv_fwd_bias_relu(argc, argv); return profile_conv_fwd_bias_relu(argc, argv);
// } }
// else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0) else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0)
// { {
// return profile_conv_fwd_bias_relu_add(argc, argv); return profile_conv_fwd_bias_relu_add(argc, argv);
// } }
// else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0) else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0)
// { {
// return profile_conv_fwd_bias_relu_atomic_add(argc, argv); return profile_conv_fwd_bias_relu_atomic_add(argc, argv);
// } }
// else if(strcmp(argv[1], "conv1d_bwd_data") == 0) else if(strcmp(argv[1], "conv1d_bwd_data") == 0)
// { {
// return profile_convnd_bwd_data(argc, argv, 1); return profile_convnd_bwd_data(argc, argv, 1);
// } }
// else if(strcmp(argv[1], "conv2d_bwd_data") == 0) else if(strcmp(argv[1], "conv2d_bwd_data") == 0)
// { {
// return profile_convnd_bwd_data(argc, argv, 2); return profile_convnd_bwd_data(argc, argv, 2);
// } }
// else if(strcmp(argv[1], "conv3d_bwd_data") == 0) else if(strcmp(argv[1], "conv3d_bwd_data") == 0)
// { {
// return profile_convnd_bwd_data(argc, argv, 3); return profile_convnd_bwd_data(argc, argv, 3);
// } }
// else if(strcmp(argv[1], "reduce") == 0) else if(strcmp(argv[1], "reduce") == 0)
// { {
// return profile_reduce(argc, argv); return profile_reduce(argc, argv);
// } }
// else if(strcmp(argv[1], "conv2d_bwd_weight") == 0) else if(strcmp(argv[1], "conv2d_bwd_weight") == 0)
// { {
// return profile_conv_bwd_weight(argc, argv); return profile_conv_bwd_weight(argc, argv);
// } }
else else
{ {
// clang-format off // clang-format off
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment