Commit e739c577 authored by Jianfeng Yan

refactor DeviceGemmXdlSplitK to support arbitrary K; remove template parameters A/B/CGridDesc from gridwise_gemm_v2r3; start rewriting conv2d_backward
parent 22d63c05
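
The heart of this refactor shows up in the gridwise_gemm_xdlops_v2r3 hunks below: the A/B/CGridDesc types move from the class template parameter list onto the individual methods (CheckValidity, CalculateGridSize, MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2, Run), so one gridwise instantiation can accept descriptors built at runtime for any K. A minimal sketch of the pattern, with simplified names rather than the actual CK signatures:

// Before: descriptor types are baked into the class, so each distinct
// descriptor type (and hence each K-shape family) forces a separate
// GridwiseGemm instantiation.
template <typename AGridDesc, typename BGridDesc, typename CGridDesc>
struct GridwiseGemmOld
{
    __device__ static void
    Run(const AGridDesc& a_desc, const BGridDesc& b_desc, const CGridDesc& c_desc);
};

// After: the class itself is descriptor-agnostic; each method deduces the
// descriptor types from its arguments, as the v2r3 hunks below do for
// CheckValidity, CalculateGridSize, and Run.
struct GridwiseGemmNew
{
    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr int CalculateGridSize(const CGridDesc_M_N& c_desc);

    template <typename AGridDesc, typename BGridDesc, typename CGridDesc>
    __device__ static void
    Run(const AGridDesc& a_desc, const BGridDesc& b_desc, const CGridDesc& c_desc);
};

The batched-GEMM block-to-C-tile mapping and the batch pointer-offset helper are also hoisted into a shared BatchedGemmUtil (the new header below), replacing per-device-op copies.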
#pragma once
#include "tuple.hpp"
#include "tensor_adaptor.hpp"
#include "multi_index_transform_helper.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

struct BatchedGemmUtil
{
    // Map a flat block id onto (batch, m0, n0) C-tile coordinates.
    // M01/N01 group adjacent tiles into clusters to improve locality.
    template <index_t MPerBlock, index_t NPerBlock>
    static constexpr auto
    MakeBlock2CTileMap(index_t batch_count, index_t M, index_t N, index_t M01 = 1, index_t N01 = 1)
    {
        constexpr auto M1 = MPerBlock;
        constexpr auto N1 = NPerBlock;

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        const auto M00 = M0 / M01;
        const auto N00 = N0 / N01;

        const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_insert_transform(batch_count),
                           make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        const auto globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(batch_count, M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto globalblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor);

        return globalblockid_to_m0_n0_block_cluster_adaptor;
    }

    // Per-batch pointer offsets for strided batched GEMM; offsets widen to
    // long_index_t so large batch strides do not overflow 32-bit arithmetic.
    struct ComputePtrOffsetOfStridedBatch
    {
        ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
                                       index_t BatchStrideB,
                                       index_t BatchStrideC)
            : BatchStrideA_(BatchStrideA),
              BatchStrideB_(BatchStrideB),
              BatchStrideC_(BatchStrideC)
        {
        }

        __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideA_);
        }

        __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideB_);
        }

        __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideC_);
        }

        private:
        index_t BatchStrideA_;
        index_t BatchStrideB_;
        index_t BatchStrideC_;
    };
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
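
For intuition, the two chained adaptors in MakeBlock2CTileMap amount to plain scalar arithmetic: invert the five-way merge to recover (g, m00, n00, m01, n01), then recombine each unmerged pair into a tile coordinate. A standalone sketch of that arithmetic (plain C++, not CK code; std::int64_t stands in for index_t):

#include <cstdint>
#include <tuple>

// Decode a flat block id into (g, m0, n0), mirroring the merge/unmerge
// transforms above. M0 = M / MPerBlock and N0 = N / NPerBlock are the tile
// counts; M01/N01 are the cluster sizes.
std::tuple<std::int64_t, std::int64_t, std::int64_t>
decode_block_id(std::int64_t block_id,
                std::int64_t M0, std::int64_t N0,
                std::int64_t M01, std::int64_t N01)
{
    const std::int64_t M00 = M0 / M01;
    const std::int64_t N00 = N0 / N01;

    // Inverse of make_merge_transform over (batch_count, M00, N00, M01, N01),
    // with batch_count the slowest-varying dimension and N01 the fastest.
    const std::int64_t n01 = block_id % N01; block_id /= N01;
    const std::int64_t m01 = block_id % M01; block_id /= M01;
    const std::int64_t n00 = block_id % N00; block_id /= N00;
    const std::int64_t m00 = block_id % M00; block_id /= M00;
    const std::int64_t g   = block_id; // the remainder is the batch index

    // make_unmerge_transform recombines each split pair into one coordinate.
    return {g, m00 * M01 + m01, n00 * N01 + n01};
}

For example, with M0 = N0 = 4 and M01 = N01 = 2, block id 7 decodes to (g, m0, n0) = (0, 1, 3). Per-batch pointer offsets are then g times the corresponding batch stride, widened to 64 bits exactly as ComputePtrOffsetOfStridedBatch does.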
@@ -13,6 +13,7 @@
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "gemm_specialization.hpp"
#include "batched_gemm_util.hpp"
namespace ck {
namespace tensor_operation {
@@ -343,43 +344,6 @@ struct DeviceGemmXdlSplitKCShuffle
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
    static constexpr auto MakeBlock2CTileMap(index_t batch_count,
                                             const CGridDesc_M_N& c_grid_desc_m_n,
                                             index_t M01,
                                             index_t N01)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        const auto M00 = M0 / M01;
        const auto N00 = N0 / N01;

        const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_insert_transform(batch_count),
                           make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));

        const auto globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(batch_count, M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto globalblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor);

        return globalblockid_to_m0_n0_block_cluster_adaptor;
    }
struct ComputePtrOffsetOfStridedBatch
{
@@ -455,7 +419,7 @@ struct DeviceGemmXdlSplitKCShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock>;
using Block2CTileMap = decltype(MakeBlock2CTileMap(1, CGridDesc_M_N{}, 1, 1));
using Block2CTileMap = decltype(BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(1, 1, 1));
struct Argument : public BaseArgument
{
@@ -529,7 +493,7 @@ struct DeviceGemmXdlSplitKCShuffle
compute_ptr_offset_of_batch_ =
ComputePtrOffsetOfStridedBatch{a_batch_stride, b_batch_stride};
block_2_ctile_map_ = MakeBlock2CTileMap(BatchCount_, c_grid_desc_m_n_, 1, 1);
block_2_ctile_map_ = BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(BatchCount_, c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1));
}
}
......
@@ -250,9 +250,6 @@ template <index_t BlockSize,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
          typename CGridDesc_M_N,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
@@ -361,6 +358,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
    // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
    template <typename AGridDesc_K0_M_K1, typename BGridDesc_K0_N_K1, typename CGridDesc_M_N>
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
@@ -420,9 +418,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return true;
}
    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr index_t
    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        static_assert(CGridDesc_M_N::GetNumOfVisibleDimension() == 2);

        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);
@@ -439,9 +439,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        return has_main_k0_block_loop;
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        static_assert(CGridDesc_M_N::GetNumOfVisibleDimension() == 2);

        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
@@ -491,11 +494,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
    MakeDefaultBlock2CTileMap(index_t M, index_t N, index_t M01, index_t N01)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};
@@ -525,11 +525,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        return cblockid_to_m0_n0_block_cluster_adaptor;
    }

    using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
        decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
    using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
    // using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
    //     decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
    using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(1, 1, 1, 1));

    template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
    template <bool HasMainK0BlockLoop,
              typename AGridDesc_K0_M_K1,
              typename BGridDesc_K0_N_K1,
              typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
              typename Block2CTileMap = DefaultBlock2CTileMap>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
......
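
With the descriptors now method-level template parameters, the kernel entry point can deduce their types per launch instead of fixing them at class instantiation. A hypothetical trimmed wrapper illustrating this (a sketch, not the actual CK kernel; the real Run also takes the element-wise operations):

#include <hip/hip_runtime.h> // HIP (or CUDA) runtime, assuming CK's usual build

// The descriptor types are deduced per launch from the kernel arguments.
// Kernel arguments must be passed by value, so the descriptors are copied
// to the device at launch time.
template <typename GridwiseGemm,
          bool HasMainK0BlockLoop,
          typename FloatAB,
          typename FloatC,
          typename ADesc,
          typename BDesc,
          typename CDesc,
          typename Block2CTileMap>
__global__ void kernel_gemm_sketch(const FloatAB* __restrict__ p_a_grid,
                                   const FloatAB* __restrict__ p_b_grid,
                                   FloatC* __restrict__ p_c_grid,
                                   const ADesc a_desc,
                                   const BDesc b_desc,
                                   const CDesc c_desc,
                                   const Block2CTileMap block_2_ctile_map)
{
    GridwiseGemm::template Run<HasMainK0BlockLoop>(
        p_a_grid, p_b_grid, p_c_grid, a_desc, b_desc, c_desc, block_2_ctile_map);
}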
# device_gemm_instance
set(DEVICE_GEMM_INSTANCE_SOURCE
device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp;
device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp;
device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp;
# device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp;
# device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp;
# device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp;
# device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp;
# device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp;
# device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
# device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
# device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp;
# device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp;
# device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
# device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
# device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp;
)
add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE})
......
@@ -24,40 +24,40 @@ include_directories(BEFORE
set(PROFILER_SOURCE
src/profiler.cpp
src/profile_gemm.cpp
src/profile_gemm_bias_2d.cpp
src/profile_gemm_bias_relu.cpp
src/profile_gemm_bias_relu_add.cpp
src/profile_gemm_reduce.cpp
src/profile_batched_gemm.cpp
src/profile_conv_fwd_bias_relu.cpp
src/profile_conv_fwd_bias_relu_add.cpp
src/profile_conv_fwd_bias_relu_atomic_add.cpp
src/profile_convnd_fwd.cpp
src/profile_convnd_bwd_data.cpp
src/profile_reduce.cpp
src/profile_grouped_gemm.cpp
src/profile_conv_bwd_weight.cpp
src/profile_batched_gemm_reduce.cpp
# src/profile_gemm_bias_2d.cpp
# src/profile_gemm_bias_relu.cpp
# src/profile_gemm_bias_relu_add.cpp
# src/profile_gemm_reduce.cpp
# src/profile_batched_gemm.cpp
# src/profile_conv_fwd_bias_relu.cpp
# src/profile_conv_fwd_bias_relu_add.cpp
# src/profile_conv_fwd_bias_relu_atomic_add.cpp
# src/profile_convnd_fwd.cpp
# src/profile_convnd_bwd_data.cpp
# src/profile_reduce.cpp
# src/profile_grouped_gemm.cpp
# src/profile_conv_bwd_weight.cpp
# src/profile_batched_gemm_reduce.cpp
)
add_executable(ckProfiler ${PROFILER_SOURCE})
target_link_libraries(ckProfiler PRIVATE host_tensor)
target_link_libraries(ckProfiler PRIVATE conv_fwd_util)
target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
# target_link_libraries(ckProfiler PRIVATE conv_fwd_util)
# target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance)
target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
# target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
# target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
# target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
# target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
# target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance)
# target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
# target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance)
# target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
# target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
# target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance)
# target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
# target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
# target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
# target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
# target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
@@ -7,19 +7,19 @@
#include "profile_convnd_fwd.hpp"
int profile_gemm(int, char*[]);
int profile_gemm_bias_2d(int, char*[]);
int profile_gemm_bias_relu(int, char*[]);
int profile_gemm_bias_relu_add(int, char*[]);
int profile_gemm_reduce(int, char*[]);
int profile_batched_gemm(int, char*[]);
int profile_grouped_gemm(int, char*[]);
int profile_conv_fwd_bias_relu(int, char*[]);
int profile_conv_fwd_bias_relu_add(int, char*[]);
int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);
int profile_convnd_bwd_data(int, char*[], int);
int profile_reduce(int, char*[]);
int profile_conv_bwd_weight(int, char*[]);
int profile_batched_gemm_reduce(int, char*[]);
// int profile_gemm_bias_2d(int, char*[]);
// int profile_gemm_bias_relu(int, char*[]);
// int profile_gemm_bias_relu_add(int, char*[]);
// int profile_gemm_reduce(int, char*[]);
// int profile_batched_gemm(int, char*[]);
// int profile_grouped_gemm(int, char*[]);
// int profile_conv_fwd_bias_relu(int, char*[]);
// int profile_conv_fwd_bias_relu_add(int, char*[]);
// int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);
// int profile_convnd_bwd_data(int, char*[], int);
// int profile_reduce(int, char*[]);
// int profile_conv_bwd_weight(int, char*[]);
// int profile_batched_gemm_reduce(int, char*[]);
int main(int argc, char* argv[])
{
@@ -27,70 +27,70 @@ int main(int argc, char* argv[])
{
return profile_gemm(argc, argv);
}
else if(strcmp(argv[1], "gemm_bias_2d") == 0)
{
return profile_gemm_bias_2d(argc, argv);
}
else if(strcmp(argv[1], "gemm_bias_relu") == 0)
{
return profile_gemm_bias_relu(argc, argv);
}
else if(strcmp(argv[1], "gemm_bias_relu_add") == 0)
{
return profile_gemm_bias_relu_add(argc, argv);
}
else if(strcmp(argv[1], "gemm_reduce") == 0)
{
return profile_gemm_reduce(argc, argv);
}
else if(strcmp(argv[1], "batched_gemm") == 0)
{
return profile_batched_gemm(argc, argv);
}
else if(strcmp(argv[1], "batched_gemm_reduce") == 0)
{
return profile_batched_gemm_reduce(argc, argv);
}
else if(strcmp(argv[1], "grouped_gemm") == 0)
{
profile_grouped_gemm(argc, argv);
}
else if(strcmp(argv[1], "conv_fwd") == 0)
{
return ck::profiler::profile_convnd_fwd(argc, argv);
}
else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
{
return profile_conv_fwd_bias_relu(argc, argv);
}
else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0)
{
return profile_conv_fwd_bias_relu_add(argc, argv);
}
else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0)
{
return profile_conv_fwd_bias_relu_atomic_add(argc, argv);
}
else if(strcmp(argv[1], "conv1d_bwd_data") == 0)
{
return profile_convnd_bwd_data(argc, argv, 1);
}
else if(strcmp(argv[1], "conv2d_bwd_data") == 0)
{
return profile_convnd_bwd_data(argc, argv, 2);
}
else if(strcmp(argv[1], "conv3d_bwd_data") == 0)
{
return profile_convnd_bwd_data(argc, argv, 3);
}
else if(strcmp(argv[1], "reduce") == 0)
{
return profile_reduce(argc, argv);
}
else if(strcmp(argv[1], "conv2d_bwd_weight") == 0)
{
return profile_conv_bwd_weight(argc, argv);
}
// else if(strcmp(argv[1], "gemm_bias_2d") == 0)
// {
// return profile_gemm_bias_2d(argc, argv);
// }
// else if(strcmp(argv[1], "gemm_bias_relu") == 0)
// {
// return profile_gemm_bias_relu(argc, argv);
// }
// else if(strcmp(argv[1], "gemm_bias_relu_add") == 0)
// {
// return profile_gemm_bias_relu_add(argc, argv);
// }
// else if(strcmp(argv[1], "gemm_reduce") == 0)
// {
// return profile_gemm_reduce(argc, argv);
// }
// else if(strcmp(argv[1], "batched_gemm") == 0)
// {
// return profile_batched_gemm(argc, argv);
// }
// else if(strcmp(argv[1], "batched_gemm_reduce") == 0)
// {
// return profile_batched_gemm_reduce(argc, argv);
// }
// else if(strcmp(argv[1], "grouped_gemm") == 0)
// {
// profile_grouped_gemm(argc, argv);
// }
// else if(strcmp(argv[1], "conv_fwd") == 0)
// {
// return ck::profiler::profile_convnd_fwd(argc, argv);
// }
// else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
// {
// return profile_conv_fwd_bias_relu(argc, argv);
// }
// else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0)
// {
// return profile_conv_fwd_bias_relu_add(argc, argv);
// }
// else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0)
// {
// return profile_conv_fwd_bias_relu_atomic_add(argc, argv);
// }
// else if(strcmp(argv[1], "conv1d_bwd_data") == 0)
// {
// return profile_convnd_bwd_data(argc, argv, 1);
// }
// else if(strcmp(argv[1], "conv2d_bwd_data") == 0)
// {
// return profile_convnd_bwd_data(argc, argv, 2);
// }
// else if(strcmp(argv[1], "conv3d_bwd_data") == 0)
// {
// return profile_convnd_bwd_data(argc, argv, 3);
// }
// else if(strcmp(argv[1], "reduce") == 0)
// {
// return profile_reduce(argc, argv);
// }
// else if(strcmp(argv[1], "conv2d_bwd_weight") == 0)
// {
// return profile_conv_bwd_weight(argc, argv);
// }
else
{
// clang-format off
......