Commit 91b19b3c authored by Chao Liu

clean

parent 4511f877
@@ -167,12 +167,14 @@ struct DeviceGemmXdlSplitKCShuffle
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};

-    static auto GetActualBatchAndKSplitted(index_t KRaw, index_t KBatch)
+    static auto GetActualBatchAndKPerBatch(index_t KRaw, index_t KBatchDesired)
     {
-        const index_t KSplitted = math::integer_divide_ceil(KRaw, KPerBlock * KBatch) * KPerBlock;
-        const index_t actual_k_batch = math::integer_divide_ceil(KRaw, KSplitted);
-        return std::make_pair(actual_k_batch, KSplitted);
+        const index_t KPerBatch =
+            math::integer_divide_ceil(KRaw, KPerBlock * KBatchDesired) * KPerBlock;
+        const index_t KBatch = math::integer_divide_ceil(KRaw, KPerBatch);
+        return std::make_tuple(KBatch, KPerBatch);
     }

     static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
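Note: the rename also clarifies the contract: the caller's k_batch is only a request, and the actual batch count is derived after the per-batch K is rounded up to a multiple of KPerBlock. A minimal standalone sketch of that arithmetic (ck::math::integer_divide_ceil is re-implemented locally for illustration; the sizes are made-up example values, not taken from this commit):

```cpp
#include <cstdio>

// Local stand-in for ck::math::integer_divide_ceil, for illustration only.
int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    // Hypothetical example values: K = 1000, KPerBlock = 32, 4 requested batches.
    const int KRaw = 1000, KPerBlock = 32, KBatchDesired = 4;

    // Round the per-batch K up to a multiple of KPerBlock.
    const int KPerBatch = integer_divide_ceil(KRaw, KPerBlock * KBatchDesired) * KPerBlock; // 256

    // The actual batch count may differ from the requested one.
    const int KBatch = integer_divide_ceil(KRaw, KPerBatch); // 4

    // When KRaw is not an exact multiple of KPerBatch, the last batch covers a shorter tail.
    const int KTail = KRaw - KPerBatch * (KBatch - 1); // 1000 - 768 = 232

    std::printf("KPerBatch=%d KBatch=%d KTail=%d\n", KPerBatch, KBatch, KTail);
}
```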
@@ -488,32 +490,33 @@ struct DeviceGemmXdlSplitKCShuffle
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
                  CElementwiseOperation c_element_op,
-                 index_t k_batch)
+                 index_t k_batch_desired)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
               p_c_grid_{p_c_grid},
-              BatchCount_(k_batch),
+              KBatch_{0},
+              has_k_batch_tail_{false},
               compute_ptr_offset_of_batch_{0, 0},
               block_2_ctile_map_{},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               c_element_op_{c_element_op}
         {
-            const auto actual_batch_and_ksplitted = GetActualBatchAndKSplitted(KRaw, k_batch);
-            BatchCount_ = actual_batch_and_ksplitted.first;
-            const auto KSplitted = actual_batch_and_ksplitted.second;
+            index_t KPerBatch = 0;
+
+            std::tie(KBatch_, KPerBatch) = GetActualBatchAndKPerBatch(KRaw, k_batch_desired);

             a_grid_desc_ak0_m_ak1_ =
-                DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KSplitted, StrideA);
+                DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KPerBatch, StrideA);
             b_grid_desc_bk0_n_bk1_ =
-                DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KSplitted, NRaw, StrideB);
+                DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KPerBatch, NRaw, StrideB);
             c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC);

-            if(KRaw != KSplitted * BatchCount_ || KRaw != KSplitted * BatchCount_)
+            if(KRaw != KPerBatch * KBatch_)
             {
-                const auto KTail = KRaw - KSplitted * (BatchCount_ - 1);
+                has_k_batch_tail_ = true;
+
+                const auto KTail = KRaw - KPerBatch * (KBatch_ - 1);

                 a_grid_desc_ak0_m_ak1_tail_ =
                     DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KTail, StrideA);
@@ -525,27 +528,27 @@ struct DeviceGemmXdlSplitKCShuffle
                 GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                     c_grid_desc_m_n_);

-            const index_t a_batch_stride = [KSplitted, StrideA]() {
+            const index_t a_batch_stride = [KPerBatch, StrideA]() {
                 if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
                 {
                     ignore = StrideA;
-                    return KSplitted;
+                    return KPerBatch;
                 }
                 else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
                 {
-                    return KSplitted * StrideA;
+                    return KPerBatch * StrideA;
                 }
             }();

-            const index_t b_batch_stride = [KSplitted, StrideB]() {
+            const index_t b_batch_stride = [KPerBatch, StrideB]() {
                 if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
                 {
-                    return KSplitted * StrideB;
+                    return KPerBatch * StrideB;
                 }
                 else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
                 {
                     ignore = StrideB;
-                    return KSplitted;
+                    return KPerBatch;
                 }
             }();
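Note: the per-batch pointer offsets follow directly from the layouts: advancing a row-major A by KPerBatch columns is a flat offset of KPerBatch elements regardless of the leading dimension, while a column-major A advances by KPerBatch * StrideA. A small standalone sketch of that arithmetic with made-up sizes:

```cpp
#include <cstdio>

int main()
{
    // Hypothetical example: an M x K matrix A with K = 8 split into batches of KPerBatch = 4.
    const int M = 4, K = 8, KPerBatch = 4;

    // Row-major A[m][k] lives at m * lda + k (lda = K), so k-batch b starts at
    // flat offset b * KPerBatch -- independent of the leading dimension.
    const int row_major_batch_stride = KPerBatch;

    // Column-major A[m][k] lives at k * lda + m (lda = M = StrideA), so k-batch b
    // starts at flat offset b * KPerBatch * lda.
    const int col_major_batch_stride = KPerBatch * M;

    for(int b = 0; b < K / KPerBatch; ++b)
        std::printf("batch %d: row-major offset %d, col-major offset %d\n",
                    b, b * row_major_batch_stride, b * col_major_batch_stride);
}
```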
@@ -553,14 +556,15 @@ struct DeviceGemmXdlSplitKCShuffle
                 ComputePtrOffsetOfStridedBatch{a_batch_stride, b_batch_stride};

             block_2_ctile_map_ = MakeBlock2CTileMap(
-                BatchCount_, c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), 1, 1);
+                KBatch_, c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), 1, 1);
         }

         // private:
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
-        index_t BatchCount_;
+        index_t KBatch_;
+        bool has_k_batch_tail_;
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_tail_;
@@ -588,7 +592,7 @@ struct DeviceGemmXdlSplitKCShuffle
             }

             {
-                std::cout << "k_batch = " << arg.BatchCount_ << "\n";
+                std::cout << "k_batch_desired = " << arg.KBatch_ << "\n";
                 std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
                           << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
                           << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
@@ -614,7 +618,7 @@ struct DeviceGemmXdlSplitKCShuffle
             }

             const index_t grid_size =
-                GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
+                GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.KBatch_;

             const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);
             const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
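Note: the grid is sized as (number of C tiles) × (actual k-batch count). For example, with a hypothetical 2048×2048 C matrix and 256×128 tiles, CalculateGridSize would return 8 · 16 = 128 tile blocks, and KBatch_ = 4 would launch 512 workgroups in total; the block-to-C-tile map built above then hands each workgroup both a C tile and a k-batch index. (The tile sizes here are illustrative, not taken from this commit.)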
@@ -634,7 +638,7 @@ struct DeviceGemmXdlSplitKCShuffle
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
-                    arg.BatchCount_,
+                    arg.KBatch_,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_,
@@ -742,11 +746,20 @@ struct DeviceGemmXdlSplitKCShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
-        return GridwiseGemm::CheckValidity(
-                   arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_) &&
-               GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_tail_,
-                                           arg.b_grid_desc_bk0_n_bk1_tail_,
-                                           arg.c_grid_desc_m_n_);
+        if(!GridwiseGemm::CheckValidity(
+               arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_))
+        {
+            return false;
+        }
+
+        if(arg.has_k_batch_tail_ && !GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_tail_,
+                                                                 arg.b_grid_desc_bk0_n_bk1_tail_,
+                                                                 arg.c_grid_desc_m_n_))
+        {
+            return false;
+        }
+
+        return true;
     }

     // polymorphic
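Note: the old check unconditionally validated the tail descriptors, which stay default-constructed when K divides evenly into the batches; the new has_k_batch_tail_ flag gates that check. A schematic sketch of how such a flag would typically be consumed at dispatch time (a hypothetical pseudo-invoker for illustration, not code from this commit):

```cpp
// Schematic host-side dispatch: the only point illustrated is that the last
// k-batch uses the *tail* descriptors when a tail exists.
template <typename Arg, typename LaunchFn>
void RunSplitK(const Arg& arg, LaunchFn launch)
{
    for(int b = 0; b < arg.KBatch_; ++b)
    {
        const bool is_tail = arg.has_k_batch_tail_ && (b == arg.KBatch_ - 1);

        if(is_tail)
            launch(arg.a_grid_desc_ak0_m_ak1_tail_, arg.b_grid_desc_bk0_n_bk1_tail_, b);
        else
            launch(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, b);
    }
}
```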
...
@@ -98,9 +98,6 @@ bool profile_gemm_splitk_impl(int do_verification,
         b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
     }

-    // set zero to c_device_buf
-    c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
-
     using AElementOp = ck::tensor_operation::element_wise::PassThrough;
     using BElementOp = ck::tensor_operation::element_wise::PassThrough;
     using CElementOp = ck::tensor_operation::element_wise::PassThrough;

@@ -115,7 +112,6 @@ bool profile_gemm_splitk_impl(int do_verification,
     a_device_buf.ToDevice(a_m_k.mData.data());
     b_device_buf.ToDevice(b_k_n.mData.data());
-    c_device_buf.ToDevice(c_m_n_device_result.mData.data());

     // add device GEMM instances
     std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmNoOpPtr> gemm_ptrs;
@@ -196,6 +192,8 @@ bool profile_gemm_splitk_impl(int do_verification,
     // profile device GEMM instances
     for(auto& gemm_ptr : gemm_ptrs)
     {
+        std::cout << gemm_ptr->GetTypeString() << std::endl;
+
         auto argument_ptr =
             gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                           static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
@@ -263,6 +261,7 @@ bool profile_gemm_splitk_impl(int do_verification,
                 a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);

             ref_invoker.Run(ref_argument);

             pass = pass &&
                    ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
...
@@ -57,309 +57,95 @@ bool profile_gemm(int argc, char* argv[])
     const int StrideB = std::stoi(argv[12]);
     const int StrideC = std::stoi(argv[13]);

+    auto profile =
+        [&](auto a_type, auto b_type, auto c_type, auto a_layout, auto b_layout, auto c_layout) {
+            using ADataType = decltype(a_type);
+            using BDataType = decltype(b_type);
+            using CDataType = decltype(c_type);
+            using ALayout = decltype(a_layout);
+            using BLayout = decltype(b_layout);
+            using CLayout = decltype(c_layout);
+
+            return ck::profiler::
+                profile_gemm_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
+                    do_verification,
+                    init_method,
+                    do_log,
+                    nrepeat,
+                    M,
+                    N,
+                    K,
+                    (StrideA < 0) ? K : StrideA,
+                    (StrideB < 0) ? N : StrideB,
+                    (StrideC < 0) ? N : StrideC);
+        };
+
+    using Row = ck::tensor_layout::gemm::RowMajor;
+    using Col = ck::tensor_layout::gemm::ColumnMajor;
+
     if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
     {
-        return ck::profiler::profile_gemm_impl<ck::half_t,
-                                               ck::half_t,
-                                               ck::half_t,
-                                               ck::tensor_layout::gemm::RowMajor,
-                                               ck::tensor_layout::gemm::RowMajor,
-                                               ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? K : StrideA,
-            (StrideB < 0) ? N : StrideB,
-            (StrideC < 0) ? N : StrideC);
+        return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Row{}, Row{}, Row{});
     }
     else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
     {
-        return ck::profiler::profile_gemm_impl<ck::half_t,
-                                               ck::half_t,
-                                               ck::half_t,
-                                               ck::tensor_layout::gemm::RowMajor,
-                                               ck::tensor_layout::gemm::ColumnMajor,
-                                               ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? K : StrideA,
-            (StrideB < 0) ? K : StrideB,
-            (StrideC < 0) ? N : StrideC);
+        return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Row{}, Col{}, Row{});
     }
     else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
     {
-        return ck::profiler::profile_gemm_impl<ck::half_t,
-                                               ck::half_t,
-                                               ck::half_t,
-                                               ck::tensor_layout::gemm::ColumnMajor,
-                                               ck::tensor_layout::gemm::RowMajor,
-                                               ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? N : StrideB,
-            (StrideC < 0) ? N : StrideC);
+        return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Col{}, Row{}, Row{});
     }
     else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
     {
-        return ck::profiler::profile_gemm_impl<ck::half_t,
-                                               ck::half_t,
-                                               ck::half_t,
-                                               ck::tensor_layout::gemm::ColumnMajor,
-                                               ck::tensor_layout::gemm::ColumnMajor,
-                                               ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? K : StrideB,
-            (StrideC < 0) ? N : StrideC);
+        return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Col{}, Col{}, Row{});
     }
     else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
     {
-        return ck::profiler::profile_gemm_impl<float,
-                                               float,
-                                               float,
-                                               ck::tensor_layout::gemm::RowMajor,
-                                               ck::tensor_layout::gemm::RowMajor,
-                                               ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? K : StrideA,
-            (StrideB < 0) ? N : StrideB,
-            (StrideC < 0) ? N : StrideC);
+        return profile(float{}, float{}, float{}, Row{}, Row{}, Row{});
     }
     else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
     {
-        return ck::profiler::profile_gemm_impl<float,
-                                               float,
-                                               float,
-                                               ck::tensor_layout::gemm::RowMajor,
-                                               ck::tensor_layout::gemm::ColumnMajor,
-                                               ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? K : StrideA,
-            (StrideB < 0) ? K : StrideB,
-            (StrideC < 0) ? N : StrideC);
+        return profile(float{}, float{}, float{}, Row{}, Col{}, Row{});
     }
     else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
     {
-        return ck::profiler::profile_gemm_impl<float,
-                                               float,
-                                               float,
-                                               ck::tensor_layout::gemm::ColumnMajor,
-                                               ck::tensor_layout::gemm::RowMajor,
-                                               ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? N : StrideB,
-            (StrideC < 0) ? N : StrideC);
+        return profile(float{}, float{}, float{}, Col{}, Row{}, Row{});
     }
     else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
     {
-        return ck::profiler::profile_gemm_impl<float,
-                                               float,
-                                               float,
-                                               ck::tensor_layout::gemm::ColumnMajor,
-                                               ck::tensor_layout::gemm::ColumnMajor,
-                                               ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? K : StrideB,
-            (StrideC < 0) ? N : StrideC);
+        return profile(float{}, float{}, float{}, Col{}, Col{}, Row{});
     }
     else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
     {
-        return ck::profiler::profile_gemm_impl<int8_t,
-                                               int8_t,
-                                               int8_t,
-                                               ck::tensor_layout::gemm::RowMajor,
-                                               ck::tensor_layout::gemm::RowMajor,
-                                               ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? K : StrideA,
-            (StrideB < 0) ? N : StrideB,
-            (StrideC < 0) ? N : StrideC);
+        return profile(int8_t{}, int8_t{}, int8_t{}, Row{}, Row{}, Row{});
     }
     else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
     {
-        return ck::profiler::profile_gemm_impl<int8_t,
-                                               int8_t,
-                                               int8_t,
-                                               ck::tensor_layout::gemm::RowMajor,
-                                               ck::tensor_layout::gemm::ColumnMajor,
-                                               ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? K : StrideB,
-            (StrideC < 0) ? N : StrideC);
+        return profile(int8_t{}, int8_t{}, int8_t{}, Row{}, Col{}, Row{});
     }
     else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
     {
-        return ck::profiler::profile_gemm_impl<int8_t,
-                                               int8_t,
-                                               int8_t,
-                                               ck::tensor_layout::gemm::ColumnMajor,
-                                               ck::tensor_layout::gemm::RowMajor,
-                                               ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? N : StrideB,
-            (StrideC < 0) ? N : StrideC);
+        return profile(int8_t{}, int8_t{}, int8_t{}, Col{}, Row{}, Row{});
     }
     else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
     {
-        return ck::profiler::profile_gemm_impl<int8_t,
-                                               int8_t,
-                                               int8_t,
-                                               ck::tensor_layout::gemm::ColumnMajor,
-                                               ck::tensor_layout::gemm::ColumnMajor,
-                                               ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? K : StrideB,
-            (StrideC < 0) ? N : StrideC);
+        return profile(int8_t{}, int8_t{}, int8_t{}, Col{}, Col{}, Row{});
     }
     else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
     {
-        return ck::profiler::profile_gemm_impl<ck::bhalf_t,
-                                               ck::bhalf_t,
-                                               ck::bhalf_t,
-                                               ck::tensor_layout::gemm::RowMajor,
-                                               ck::tensor_layout::gemm::RowMajor,
-                                               ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? K : StrideA,
-            (StrideB < 0) ? N : StrideB,
-            (StrideC < 0) ? N : StrideC);
+        return profile(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, Row{}, Row{}, Row{});
     }
     else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
     {
-        return ck::profiler::profile_gemm_impl<ck::bhalf_t,
-                                               ck::bhalf_t,
-                                               ck::bhalf_t,
-                                               ck::tensor_layout::gemm::RowMajor,
-                                               ck::tensor_layout::gemm::ColumnMajor,
-                                               ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? K : StrideB,
-            (StrideC < 0) ? N : StrideC);
+        return profile(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, Row{}, Col{}, Row{});
     }
     else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
     {
-        return ck::profiler::profile_gemm_impl<ck::bhalf_t,
-                                               ck::bhalf_t,
-                                               ck::bhalf_t,
-                                               ck::tensor_layout::gemm::ColumnMajor,
-                                               ck::tensor_layout::gemm::RowMajor,
-                                               ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? N : StrideB,
-            (StrideC < 0) ? N : StrideC);
+        return profile(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, Col{}, Row{}, Row{});
     }
     else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
     {
-        return ck::profiler::profile_gemm_impl<ck::bhalf_t,
-                                               ck::bhalf_t,
-                                               ck::bhalf_t,
-                                               ck::tensor_layout::gemm::ColumnMajor,
-                                               ck::tensor_layout::gemm::ColumnMajor,
-                                               ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? K : StrideB,
-            (StrideC < 0) ? N : StrideC);
+        return profile(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, Col{}, Col{}, Row{});
     }
     else
     {
...
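Note: the refactor above collapses sixteen nearly identical branches into one generic lambda. Callers pass value-initialized tag objects (ck::half_t{}, Row{}, ...) purely so the lambda can recover their types via decltype. A minimal self-contained sketch of the pattern (the names here are illustrative, not from the library):

```cpp
#include <cstdio>
#include <typeinfo>

struct RowMajor {};
struct ColumnMajor {};

int main()
{
    // The tag arguments carry no data; only their static types matter.
    auto dispatch = [](auto data_tag, auto layout_tag) {
        using DataType = decltype(data_tag);
        using Layout   = decltype(layout_tag);

        std::printf("data=%s layout=%s\n", typeid(DataType).name(), typeid(Layout).name());
    };

    dispatch(float{}, RowMajor{});    // instantiates the float/RowMajor version
    dispatch(short{}, ColumnMajor{}); // a second, independent instantiation
}
```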
@@ -25,7 +25,7 @@ bool profile_gemm_splitk(int argc, char* argv[])
     if(argc != 15)
     {
-        printf("arg1: tensor operation (gemm: GEMM)\n");
+        printf("arg1: tensor operation (gemm: GEMMSplitK)\n");
         printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
         printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
         printf("       1: A[m, k] * B[n, k] = C[m, n];\n");

@@ -56,165 +56,68 @@ bool profile_gemm_splitk(int argc, char* argv[])
     const int StrideC = std::stoi(argv[13]);
     const int KBatch = std::stoi(argv[14]);

+    auto profile = [&](auto a_type,
+                       auto b_type,
+                       auto c_type,
+                       auto a_layout,
+                       auto b_layout,
+                       auto c_layout) {
+        using ADataType = decltype(a_type);
+        using BDataType = decltype(b_type);
+        using CDataType = decltype(c_type);
+        using ALayout = decltype(a_layout);
+        using BLayout = decltype(b_layout);
+        using CLayout = decltype(c_layout);
+
+        return ck::profiler::
+            profile_gemm_splitk_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
+                do_verification,
+                init_method,
+                do_log,
+                nrepeat,
+                M,
+                N,
+                K,
+                (StrideA < 0) ? K : StrideA,
+                (StrideB < 0) ? N : StrideB,
+                (StrideC < 0) ? N : StrideC,
+                KBatch);
+    };
+
+    using Row = ck::tensor_layout::gemm::RowMajor;
+    using Col = ck::tensor_layout::gemm::ColumnMajor;
+
     if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
     {
-        return ck::profiler::profile_gemm_splitk_impl<ck::half_t,
-                                                      ck::half_t,
-                                                      ck::half_t,
-                                                      ck::tensor_layout::gemm::RowMajor,
-                                                      ck::tensor_layout::gemm::RowMajor,
-                                                      ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? K : StrideA,
-            (StrideB < 0) ? N : StrideB,
-            (StrideC < 0) ? N : StrideC,
-            KBatch);
+        return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Row{}, Row{}, Row{});
     }
     else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
     {
-        return ck::profiler::profile_gemm_splitk_impl<ck::half_t,
-                                                      ck::half_t,
-                                                      ck::half_t,
-                                                      ck::tensor_layout::gemm::RowMajor,
-                                                      ck::tensor_layout::gemm::ColumnMajor,
-                                                      ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? K : StrideA,
-            (StrideB < 0) ? K : StrideB,
-            (StrideC < 0) ? N : StrideC,
-            KBatch);
+        return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Row{}, Col{}, Row{});
     }
     else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
     {
-        return ck::profiler::profile_gemm_splitk_impl<ck::half_t,
-                                                      ck::half_t,
-                                                      ck::half_t,
-                                                      ck::tensor_layout::gemm::ColumnMajor,
-                                                      ck::tensor_layout::gemm::RowMajor,
-                                                      ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? N : StrideB,
-            (StrideC < 0) ? N : StrideC,
-            KBatch);
+        return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Col{}, Row{}, Row{});
     }
     else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
     {
-        return ck::profiler::profile_gemm_splitk_impl<ck::half_t,
-                                                      ck::half_t,
-                                                      ck::half_t,
-                                                      ck::tensor_layout::gemm::ColumnMajor,
-                                                      ck::tensor_layout::gemm::ColumnMajor,
-                                                      ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? K : StrideB,
-            (StrideC < 0) ? N : StrideC,
-            KBatch);
+        return profile(ck::half_t{}, ck::half_t{}, ck::half_t{}, Col{}, Col{}, Row{});
     }
     else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
     {
-        return ck::profiler::profile_gemm_splitk_impl<float,
-                                                      float,
-                                                      float,
-                                                      ck::tensor_layout::gemm::RowMajor,
-                                                      ck::tensor_layout::gemm::RowMajor,
-                                                      ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? K : StrideA,
-            (StrideB < 0) ? N : StrideB,
-            (StrideC < 0) ? N : StrideC,
-            KBatch);
+        return profile(float{}, float{}, float{}, Row{}, Row{}, Row{});
     }
     else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
     {
-        return ck::profiler::profile_gemm_splitk_impl<float,
-                                                      float,
-                                                      float,
-                                                      ck::tensor_layout::gemm::RowMajor,
-                                                      ck::tensor_layout::gemm::ColumnMajor,
-                                                      ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? K : StrideA,
-            (StrideB < 0) ? K : StrideB,
-            (StrideC < 0) ? N : StrideC,
-            KBatch);
+        return profile(float{}, float{}, float{}, Row{}, Col{}, Row{});
     }
     else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
     {
-        return ck::profiler::profile_gemm_splitk_impl<float,
-                                                      float,
-                                                      float,
-                                                      ck::tensor_layout::gemm::ColumnMajor,
-                                                      ck::tensor_layout::gemm::RowMajor,
-                                                      ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? N : StrideB,
-            (StrideC < 0) ? N : StrideC,
-            KBatch);
+        return profile(float{}, float{}, float{}, Col{}, Row{}, Row{});
     }
     else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
     {
-        return ck::profiler::profile_gemm_splitk_impl<float,
-                                                      float,
-                                                      float,
-                                                      ck::tensor_layout::gemm::ColumnMajor,
-                                                      ck::tensor_layout::gemm::ColumnMajor,
-                                                      ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? K : StrideB,
-            (StrideC < 0) ? N : StrideC,
-            KBatch);
+        return profile(float{}, float{}, float{}, Col{}, Col{}, Row{});
     }
     else
     {
...
@@ -22,13 +22,39 @@ bool profile_batched_gemm_reduce(int, char*[]);

 int main(int argc, char* argv[])
 {
+    auto print_help_message = []() {
+        // clang-format off
+        printf("arg1: tensor operation, gemm: GEMM\n"
+               "                        gemm_splitk: GEMM Split-K\n"
+               "                        gemm_bias_2d: GEMM+Bias(2D)\n"
+               "                        gemm_bias_relu: GEMM+Bias+ReLU\n"
+               "                        gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n"
+               "                        gemm_reduce: GEMM+Reduce\n"
+               "                        grouped_gemm: Grouped GEMM\n"
+               "                        conv_fwd: Convolution Forward\n"
+               "                        conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n"
+               "                        conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n"
+               "                        conv1d_bwd_data: Convolution Backward Data 1D\n"
+               "                        conv2d_bwd_data: Convolution Backward Data 2D\n"
+               "                        conv3d_bwd_data: Convolution Backward Data 3D\n"
+               "                        reduce: Reduce\n"
+               "                        conv2d_bwd_weight: Convolution Backward Weight 2D\n");
+        // clang-format on
+    };
+
+    if(argc < 2)
+    {
+        print_help_message();
+        exit(1);
+    }
+
     bool pass = true;

     if(strcmp(argv[1], "gemm") == 0)
     {
         pass = profile_gemm(argc, argv);
     }
-    if(strcmp(argv[1], "gemm_splitk") == 0)
+    else if(strcmp(argv[1], "gemm_splitk") == 0)
     {
         pass = profile_gemm_splitk(argc, argv);
     }
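Note: the early argc < 2 guard matters because every strcmp(argv[1], ...) in the chain below reads argv[1] unconditionally; invoking the profiler with no operation argument previously passed argv[1] == NULL to strcmp, which is undefined behavior. Changing the gemm_splitk branch from if to else if also makes the dispatch mutually exclusive, consistent with the rest of the chain.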
@@ -94,23 +120,7 @@ int main(int argc, char* argv[])
     }
     else
     {
-        // clang-format off
-        printf("arg1: tensor operation, gemm: GEMM\n"
-               "                        gemm_splitk: GEMMSplitK)\n"
-               "                        gemm_bias_2d: GEMM+Bias(2D)\n"
-               "                        gemm_bias_relu: GEMM+Bias+ReLU\n"
-               "                        gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n"
-               "                        gemm_reduce: GEMM+Reduce\n"
-               "                        grouped_gemm: Grouped GEMM\n"
-               "                        conv_fwd: ForwardConvolution\n"
-               "                        conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n"
-               "                        conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n"
-               "                        conv1d_bwd_data: BackwardConvolution data 1 dim\n"
-               "                        conv2d_bwd_data: BackwardConvolution data 2 dim\n"
-               "                        conv3d_bwd_data: BackwardConvolution data 3 dim\n"
-               "                        reduce: Reduce\n"
-               "                        conv2d_bwd_weight: Backward Weight Convolution 2d\n");
-        // clang-format on
+        print_help_message();
     }

     return pass ? 0 : 1;
...