Commit 4511f877 authored by Chao Liu

refactor profiler: return pass/fail as bool from profiler entry points
instead of calling exit(), move per-operation enums into function scope,
zero the C buffer before profiling each kernel instance, and compute
reference results once, outside the per-instance loop

parent 519b6aaf
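With this change a driver can map the profiler result to a process exit code. A minimal sketch (the main() below is illustrative only and not part of this commit; profile_gemm is one of the entry points refactored here):

#include <cstdlib>

bool profile_gemm(int argc, char* argv[]); // declared in the profiler sources

int main(int argc, char* argv[])
{
    // map the bool pass/fail result to a conventional process exit code
    const bool pass = profile_gemm(argc, argv);
    return pass ? EXIT_SUCCESS : EXIT_FAILURE;
}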
......@@ -60,8 +60,6 @@ bool profile_gemm_reduce_impl(int do_verification,
int StrideB,
int StrideC)
{
bool pass = true;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
......@@ -209,15 +207,13 @@ bool profile_gemm_reduce_impl(int do_verification,
}
}
if(gemm_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device GEMM instance found");
}
std::cout << "found " << gemm_ptrs.size() << " instances" << std::endl;
std::string best_gemm_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
bool pass = true;
// profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs)
......
#pragma once
#include <iomanip>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_conv.hpp"
#include "tensor_layout.hpp"
#include "device_tensor.hpp"
#include "element_wise_operation.hpp"
#include "device_gemm.hpp"
#include "reference_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
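// Plain GEMM instance pointer: PassThrough (identity) element-wise operators on A, B and C,
// i.e. no fused epilogue.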
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
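// Instance factories, one per (A, B) layout combination; each appends its split-K
// instances to the given vector. Definitions live with the device instance implementations.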
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
namespace ck {
namespace profiler {
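// Profiles every registered split-K GEMM instance for the given data types and layouts;
// KBatch is the number of chunks the K dimension is split into. Returns true only if all
// supported, verified instances match the host reference.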
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
bool profile_gemm_splitk_impl(int do_verification,
int init_method,
bool do_log,
int nrepeat,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideC,
int KBatch)
{
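// Build a 2-D HostTensorDescriptor from (row, col, stride, layout):
// row-major uses strides {stride, 1}, column-major uses {1, stride}.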
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
std::size_t num_thread = 1;
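// init_method: 0 = leave tensors uninitialized, 1 = random integers in [-5, 5],
// otherwise random decimal values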
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
}
// initialize C to zero on the host; the values are copied into c_device_buf below
c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
// add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmNoOpPtr> gemm_ptrs;
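// Pick the instance factory matching the compile-time data type / layout combination;
// unsupported combinations leave gemm_ptrs empty.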
if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
is_same<CDataType, float>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
}
}
else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
}
}

if(gemm_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device GEMM instance found");
}

std::cout << "found " << gemm_ptrs.size() << " instances" << std::endl;
std::string best_gemm_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
bool pass = true;
// profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
KBatch);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
std::string gemm_name = gemm_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
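// GEMM performs 2*M*N*K FLOPs and touches A (M*K), B (K*N) and C (M*N) once;
// ave_time is in ms, so flop / 1e9 / ms = TFLOP/s and bytes / 1e6 / ms = GB/s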
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << gemm_name << std::endl;
if(tflops > best_tflops)
{
best_gemm_name = gemm_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
pass = pass &&
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host : ", c_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl;
}
}
}
else
{
std::cout << "does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck
......@@ -46,7 +46,7 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout>
void profile_grouped_gemm_impl(int do_verification,
bool profile_grouped_gemm_impl(int do_verification,
int init_method,
bool do_log,
int nrepeat,
......@@ -57,6 +57,8 @@ void profile_grouped_gemm_impl(int do_verification,
std::vector<int> StrideBs,
std::vector<int> StrideCs)
{
bool pass = true;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
......@@ -81,6 +83,7 @@ void profile_grouped_gemm_impl(int do_verification,
std::vector<Tensor<ADataType>> a_m_k;
std::vector<Tensor<BDataType>> b_k_n;
std::vector<Tensor<CDataType>> c_m_n_host_results;
std::vector<Tensor<CDataType>> c_m_n_device_results;
for(int i = 0; i < Ms.size(); i++)
......@@ -90,6 +93,9 @@ void profile_grouped_gemm_impl(int do_verification,
b_k_n.push_back(
Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
c_m_n_host_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
c_m_n_device_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
......@@ -121,11 +127,6 @@ void profile_grouped_gemm_impl(int do_verification,
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
// if(do_verification)
// {
// }
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_device_buf, b_device_buf, c_device_buf;
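// one buffer per GEMM in the group; the raw device pointers are collected below
// and handed to the grouped kernel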
......@@ -165,6 +166,27 @@ void profile_grouped_gemm_impl(int do_verification,
p_c.push_back(c_device_buf[i]->GetDeviceBuffer());
}
// reference calculation: compute the host results once, outside the instance loop,
// so every device instance is verified against the same baseline
if(do_verification)
{
for(int i = 0; i < gemm_shapes.size(); i++)
{
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
b_k_n[i],
c_m_n_host_results[i],
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
}
}
// add device GEMM instances
std::vector<
ck::tensor_operation::device::device_grouped_gemm_instance::DeviceGroupedGemmNoOpPtr>
......@@ -229,6 +251,12 @@ void profile_grouped_gemm_impl(int do_verification,
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
for(int i = 0; i < gemm_shapes.size(); i++)
{
c_device_buf[i]->SetZero();
}
std::string gemm_name = gemm_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
......@@ -260,32 +288,10 @@ void profile_grouped_gemm_impl(int do_verification,
{
for(int i = 0; i < gemm_shapes.size(); i++)
{
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}));
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
b_k_n[i],
c_m_n_host_result,
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
ck::utils::check_err(c_m_n_device_results[i].mData, c_m_n_host_result.mData);
pass = pass && ck::utils::check_err(c_m_n_device_results[i].mData,
c_m_n_host_results[i].mData);
if(do_log)
{
......@@ -296,7 +302,7 @@ void profile_grouped_gemm_impl(int do_verification,
std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
std::cout << "c_host : ", c_m_n_host_results[i].mData, ",")
<< std::endl;
}
}
......@@ -310,6 +316,9 @@ void profile_grouped_gemm_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
return pass;
}
} // namespace profiler
......
......@@ -16,8 +16,10 @@
#include "device_batched_gemm_xdl.hpp"
#include "profile_batched_gemm_impl.hpp"
enum struct GemmMatrixLayout
bool profile_batched_gemm(int argc, char* argv[])
{
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
......@@ -26,18 +28,16 @@ enum struct GemmMatrixLayout
MK_NK_NM, // 5
KM_KN_NM, // 6
KM_NK_NM, // 7
};
};
enum struct GemmDataType
{
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
};
};
int profile_batched_gemm(int argc, char* argv[])
{
if(!(argc == 15))
{
printf("arg1: tensor operation (batched_gemm: Batched GEMM)\n");
......@@ -51,7 +51,7 @@ int profile_batched_gemm(int argc, char* argv[])
printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount\n");
exit(1);
return false;
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
......@@ -73,7 +73,7 @@ int profile_batched_gemm(int argc, char* argv[])
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_batched_gemm_impl<ck::half_t,
return ck::profiler::profile_batched_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
......@@ -93,7 +93,7 @@ int profile_batched_gemm(int argc, char* argv[])
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_batched_gemm_impl<ck::half_t,
return ck::profiler::profile_batched_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
......@@ -113,7 +113,7 @@ int profile_batched_gemm(int argc, char* argv[])
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_batched_gemm_impl<ck::half_t,
return ck::profiler::profile_batched_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -133,7 +133,7 @@ int profile_batched_gemm(int argc, char* argv[])
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_batched_gemm_impl<ck::half_t,
return ck::profiler::profile_batched_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -153,7 +153,7 @@ int profile_batched_gemm(int argc, char* argv[])
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
return ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
ck::tensor_layout::gemm::RowMajor,
......@@ -173,7 +173,7 @@ int profile_batched_gemm(int argc, char* argv[])
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
return ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
ck::tensor_layout::gemm::RowMajor,
......@@ -193,7 +193,7 @@ int profile_batched_gemm(int argc, char* argv[])
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
return ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -213,7 +213,7 @@ int profile_batched_gemm(int argc, char* argv[])
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
return ck::profiler::profile_batched_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -233,7 +233,7 @@ int profile_batched_gemm(int argc, char* argv[])
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_batched_gemm_impl<float,
return ck::profiler::profile_batched_gemm_impl<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
......@@ -253,7 +253,7 @@ int profile_batched_gemm(int argc, char* argv[])
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_batched_gemm_impl<float,
return ck::profiler::profile_batched_gemm_impl<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
......@@ -273,7 +273,7 @@ int profile_batched_gemm(int argc, char* argv[])
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_batched_gemm_impl<float,
return ck::profiler::profile_batched_gemm_impl<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -293,7 +293,7 @@ int profile_batched_gemm(int argc, char* argv[])
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_batched_gemm_impl<float,
return ck::profiler::profile_batched_gemm_impl<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -313,7 +313,7 @@ int profile_batched_gemm(int argc, char* argv[])
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_batched_gemm_impl<int8_t,
return ck::profiler::profile_batched_gemm_impl<int8_t,
int8_t,
int8_t,
ck::tensor_layout::gemm::RowMajor,
......@@ -333,7 +333,7 @@ int profile_batched_gemm(int argc, char* argv[])
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_batched_gemm_impl<int8_t,
return ck::profiler::profile_batched_gemm_impl<int8_t,
int8_t,
int8_t,
ck::tensor_layout::gemm::RowMajor,
......@@ -353,7 +353,7 @@ int profile_batched_gemm(int argc, char* argv[])
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_batched_gemm_impl<int8_t,
return ck::profiler::profile_batched_gemm_impl<int8_t,
int8_t,
int8_t,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -373,7 +373,7 @@ int profile_batched_gemm(int argc, char* argv[])
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_batched_gemm_impl<int8_t,
return ck::profiler::profile_batched_gemm_impl<int8_t,
int8_t,
int8_t,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -393,8 +393,8 @@ int profile_batched_gemm(int argc, char* argv[])
}
else
{
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
return true;
}
}
......@@ -7,7 +7,7 @@
#include "profile_batched_gemm_reduce_impl.hpp"
int profile_batched_gemm_reduce(int argc, char* argv[])
bool profile_batched_gemm_reduce(int argc, char* argv[])
{
enum struct GemmMatrixLayout
{
......@@ -23,7 +23,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
F16_F16_F16_F32_F32, // 1
};
if(!(argc == 15 || argc == 16))
if(argc != 15)
{
printf("arg1: tensor operation (batched_gemm: BatchedGEMM+Reduce)\n");
printf("arg2: data type (0: fp32; 1: fp16)\n");
......@@ -36,8 +36,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount\n");
printf("arg15: split k into mulitiple batch\n");
exit(1);
return false;
}
const auto data_type = static_cast<GemmReduceDataType>(std::stoi(argv[2]));
......@@ -59,7 +58,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
return ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
......@@ -81,7 +80,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
return ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
......@@ -103,7 +102,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
return ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
......@@ -125,7 +124,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
return ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
......@@ -146,8 +145,8 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
}
else
{
throw std::runtime_error("wrong! this data_type & layout is not implemented");
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
return true;
}
}
......@@ -4,36 +4,37 @@
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "profile_conv_bwd_data_impl.hpp"
enum struct ConvDataType
int profile_conv_bwd_data(int argc, char* argv[])
{
enum struct ConvDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
};
};
enum struct ConvInputLayout
{
enum struct ConvInputLayout
{
NCHW, // 0
NHWC, // 1
};
};
enum struct ConvWeightLayout
{
enum struct ConvWeightLayout
{
KCYX, // 0
KYXC, // 1
};
};
enum struct ConvOutputLayout
{
enum struct ConvOutputLayout
{
NKHW, // 0
NHWK, // 1
};
};
int profile_conv_bwd_data(int argc, char* argv[])
{
if(argc != 25)
{
printf("arg1: tensor operation (conv_bwd: BackwardConvolution)\n");
......@@ -47,7 +48,7 @@ int profile_conv_bwd_data(int argc, char* argv[])
printf("arg9: run kernel # of times (>1)\n");
printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n");
exit(1);
return false;
}
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
......@@ -85,7 +86,7 @@ int profile_conv_bwd_data(int argc, char* argv[])
if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{
ck::profiler::profile_conv_bwd_data_impl<2,
return ck::profiler::profile_conv_bwd_data_impl<2,
float,
float,
float,
......@@ -111,7 +112,7 @@ int profile_conv_bwd_data(int argc, char* argv[])
else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{
ck::profiler::profile_conv_bwd_data_impl<2,
return ck::profiler::profile_conv_bwd_data_impl<2,
ck::half_t,
ck::half_t,
ck::half_t,
......@@ -137,7 +138,7 @@ int profile_conv_bwd_data(int argc, char* argv[])
else if(data_type == ConvDataType::BF16_BF16_BF16 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{
ck::profiler::profile_conv_bwd_data_impl<2,
return ck::profiler::profile_conv_bwd_data_impl<2,
uint16_t,
uint16_t,
uint16_t,
......@@ -163,7 +164,7 @@ int profile_conv_bwd_data(int argc, char* argv[])
else if(data_type == ConvDataType::INT8_INT8_INT8 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{
ck::profiler::profile_conv_bwd_data_impl<2,
return ck::profiler::profile_conv_bwd_data_impl<2,
int8_t,
int8_t,
int8_t,
......@@ -188,8 +189,8 @@ int profile_conv_bwd_data(int argc, char* argv[])
}
else
{
throw std::runtime_error("wrong! this Conv data_type & layout is not implemented");
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
return true;
}
}
......@@ -6,34 +6,35 @@
#include <half.hpp>
#include "profile_conv_bwd_weight_impl.hpp"
enum struct ConvDataType
// returns true if the test passes
bool profile_conv_bwd_weight(int argc, char* argv[])
{
enum struct ConvDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
};
};
enum struct ConvInputLayout
{
enum struct ConvInputLayout
{
NCHW, // 0
NHWC, // 1
};
};
enum struct ConvWeightLayout
{
enum struct ConvWeightLayout
{
KCYX, // 0
KYXC, // 1
};
};
enum struct ConvOutputLayout
{
enum struct ConvOutputLayout
{
NKHW, // 0
NHWK, // 1
};
};
int profile_conv_bwd_weight(int argc, char* argv[])
{
if(argc != 26)
{
printf("arg1: tensor operation (conv_fwd: ForwardConvolution)\n");
......@@ -48,7 +49,7 @@ int profile_conv_bwd_weight(int argc, char* argv[])
printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n");
printf("arg25: split k (>=1)\n");
exit(1);
return false;
}
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
......@@ -88,7 +89,7 @@ int profile_conv_bwd_weight(int argc, char* argv[])
if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{
ck::profiler::profile_conv_bwd_weight_impl<2,
return ck::profiler::profile_conv_bwd_weight_impl<2,
float,
float,
float,
......@@ -114,7 +115,7 @@ int profile_conv_bwd_weight(int argc, char* argv[])
else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{
ck::profiler::profile_conv_bwd_weight_impl<2,
return ck::profiler::profile_conv_bwd_weight_impl<2,
ck::half_t,
ck::half_t,
ck::half_t,
......@@ -139,8 +140,8 @@ int profile_conv_bwd_weight(int argc, char* argv[])
}
else
{
throw std::runtime_error("wrong! this Conv data_type & layout is not implemented");
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
return true;
}
}
......@@ -6,32 +6,33 @@
#include <half.hpp>
#include "profile_conv_fwd_bias_relu_impl.hpp"
enum struct ConvDataType
// returns true if the test passes
bool profile_conv_fwd_bias_relu(int argc, char* argv[])
{
enum struct ConvDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
};
};
enum struct ConvInputLayout
{
enum struct ConvInputLayout
{
NCHW, // 0
NHWC, // 1
};
};
enum struct ConvWeightLayout
{
enum struct ConvWeightLayout
{
KCYX, // 0
KYXC, // 1
};
};
enum struct ConvOutputLayout
{
enum struct ConvOutputLayout
{
NKHW, // 0
NHWK, // 1
};
};
int profile_conv_fwd_bias_relu(int argc, char* argv[])
{
if(argc != 25)
{
printf("arg1: tensor operation (conv_fwd_bias_relu: ForwardConvolution+Bias+ReLu)\n");
......@@ -45,7 +46,7 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
printf("arg9: run kernel # of times (>1)\n");
printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n");
exit(1);
return false;
}
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
......@@ -83,7 +84,7 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{
ck::profiler::profile_conv_fwd_bias_relu_impl<2,
return ck::profiler::profile_conv_fwd_bias_relu_impl<2,
ck::half_t,
ck::half_t,
ck::half_t,
......@@ -107,8 +108,8 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[])
}
else
{
throw std::runtime_error("wrong! data_type & layout for this operator is not implemented");
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
return true;
}
}
......@@ -6,32 +6,32 @@
#include <half.hpp>
#include "profile_conv_fwd_bias_relu_add_impl.hpp"
enum struct ConvDataType
bool profile_conv_fwd_bias_relu_add(int argc, char* argv[])
{
enum struct ConvDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
};
};
enum struct ConvInputLayout
{
enum struct ConvInputLayout
{
NCHW, // 0
NHWC, // 1
};
};
enum struct ConvWeightLayout
{
enum struct ConvWeightLayout
{
KCYX, // 0
KYXC, // 1
};
};
enum struct ConvOutputLayout
{
enum struct ConvOutputLayout
{
NKHW, // 0
NHWK, // 1
};
};
int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
{
if(argc != 25)
{
printf(
......@@ -46,7 +46,7 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
printf("arg9: run kernel # of times (>1)\n");
printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n");
exit(1);
return false;
}
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
......@@ -84,7 +84,8 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{
ck::profiler::profile_conv_fwd_bias_relu_add_impl<2,
return ck::profiler::profile_conv_fwd_bias_relu_add_impl<
2,
ck::half_t,
ck::half_t,
ck::half_t,
......@@ -108,8 +109,8 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[])
}
else
{
throw std::runtime_error("wrong! data_type & layout for this operator is not implemented");
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
return true;
}
}
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "profile_conv_fwd_bias_relu_atomic_add_impl.hpp"
enum struct ConvDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
};
enum struct ConvInputLayout
{
NCHW, // 0
NHWC, // 1
};
enum struct ConvWeightLayout
{
KCYX, // 0
KYXC, // 1
};
enum struct ConvOutputLayout
{
NKHW, // 0
NHWK, // 1
};
int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[])
{
if(argc != 25)
{
printf("arg1: tensor operation (conv_fwd_bias_relu_atomic_add: "
"ForwardConvolution+Bias+ReLu+AtomicAdd)\n");
printf("arg2: data type (0: fp32; 1: fp16)\n");
printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n");
printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n");
printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n");
printf("arg6: verification (0: no; 1: yes)\n");
printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg9: run kernel # of times (>1)\n");
printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n");
exit(1);
}
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const auto in_layout = static_cast<ConvInputLayout>(std::stoi(argv[3]));
const auto wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
const auto out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
const bool do_verification = std::stoi(argv[6]);
const int init_method = std::stoi(argv[7]);
const bool do_log = std::stoi(argv[8]);
const int nrepeat = std::stoi(argv[9]);
const ck::index_t N = std::stoi(argv[10]);
const ck::index_t K = std::stoi(argv[11]);
const ck::index_t C = std::stoi(argv[12]);
const ck::index_t Y = std::stoi(argv[13]);
const ck::index_t X = std::stoi(argv[14]);
const ck::index_t Hi = std::stoi(argv[15]);
const ck::index_t Wi = std::stoi(argv[16]);
const ck::index_t conv_stride_h = std::stoi(argv[17]);
const ck::index_t conv_stride_w = std::stoi(argv[18]);
const ck::index_t conv_dilation_h = std::stoi(argv[19]);
const ck::index_t conv_dilation_w = std::stoi(argv[20]);
const ck::index_t in_left_pad_h = std::stoi(argv[21]);
const ck::index_t in_left_pad_w = std::stoi(argv[22]);
const ck::index_t in_right_pad_h = std::stoi(argv[23]);
const ck::index_t in_right_pad_w = std::stoi(argv[24]);
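// standard convolution output size: effective filter extent YEff = (Y - 1) * dilation + 1,
// then Ho = (Hi + left_pad + right_pad - YEff) / stride + 1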
const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1;
const ck::index_t XEff = (X - 1) * conv_dilation_w + 1;
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{
ck::profiler::profile_conv_fwd_bias_relu_atomic_add_impl<
2,
ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::NHWK>(
do_verification,
init_method,
do_log,
nrepeat,
N,
K,
C,
std::vector<ck::index_t>{Hi, Wi},
std::vector<ck::index_t>{Y, X},
std::vector<ck::index_t>{Ho, Wo},
std::vector<ck::index_t>{conv_stride_h, conv_stride_w},
std::vector<ck::index_t>{conv_dilation_h, conv_dilation_w},
std::vector<ck::index_t>{in_left_pad_h, in_left_pad_w},
std::vector<ck::index_t>{in_right_pad_h, in_right_pad_w});
}
else
{
throw std::runtime_error("wrong! data_type & layout for this operator is not implemented");
}
return 1;
}
......@@ -80,7 +80,7 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[],
} // namespace
int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
bool profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
{
const int preParams = 10;
int conv_args = 3 + num_dim_spatial * 6;
......@@ -98,7 +98,7 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
printf("arg9: run kernel # of times (>1)\n");
printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n");
return 1;
return false;
}
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
......@@ -121,7 +121,7 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
switch(num_dim_spatial)
{
case 1:
ck::profiler::profile_convnd_bwd_data_impl<1,
return ck::profiler::profile_convnd_bwd_data_impl<1,
InDataType,
WeiDataType,
OutDataType,
......@@ -146,7 +146,7 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
break;
case 2:
ck::profiler::profile_convnd_bwd_data_impl<2,
return ck::profiler::profile_convnd_bwd_data_impl<2,
InDataType,
WeiDataType,
OutDataType,
......@@ -171,15 +171,15 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
break;
case 3:
ck::profiler::profile_convnd_bwd_data_impl<3,
return ck::profiler::profile_convnd_bwd_data_impl<
3,
InDataType,
WeiDataType,
OutDataType,
AccDataType,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK>(
do_verification,
ck::tensor_layout::convolution::NDHWK>(do_verification,
init_method,
do_log,
nrepeat,
......@@ -195,34 +195,34 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
params.input_right_pads);
break;
default: break;
default: return false;
}
};
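// Run dispatches on the runtime num_dim_spatial (1-D/2-D/3-D impls); the branches below
// instantiate it for each supported data-type combination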
if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{
Run(float{}, float{}, float{}, float{});
return Run(float{}, float{}, float{}, float{});
}
else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{
Run(ck::half_t{}, ck::half_t{}, ck::half_t{}, float{});
return Run(ck::half_t{}, ck::half_t{}, ck::half_t{}, float{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{
Run(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, float{});
return Run(ck::bhalf_t{}, ck::bhalf_t{}, ck::bhalf_t{}, float{});
}
else if(data_type == ConvDataType::INT8_INT8_INT8 && in_layout == ConvInputLayout::NHWC &&
wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK)
{
Run(int8_t{}, int8_t{}, int8_t{}, int32_t{});
return Run(int8_t{}, int8_t{}, int8_t{}, int32_t{});
}
else
{
std::cout << "wrong! this Conv data_type & layout is not implemented" << std::endl;
return 1;
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 0;
return true;
}
}
......@@ -8,7 +8,6 @@
#include "conv_fwd_util.hpp"
#include "element_wise_operation.hpp"
#include "fill.hpp"
#include "profile_convnd_fwd.hpp"
#include "tensor_layout.hpp"
namespace {
......@@ -295,7 +294,7 @@ void profile_convnd_instances(ConvDataType data_type,
} // namespace
int ck::profiler::profile_convnd_fwd(int argc, char* argv[])
bool profile_convnd_fwd(int argc, char* argv[])
{
using namespace ck::utils::conv;
......@@ -347,5 +346,6 @@ int ck::profiler::profile_convnd_fwd(int argc, char* argv[])
std::to_string(num_dim_spatial));
}
return 1;
// FIXME: return true if the test passes, false if it fails
return true;
}
......@@ -4,31 +4,29 @@
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "profile_gemm_impl.hpp"
enum struct GemmMatrixLayout
// returns true if the test passes
bool profile_gemm(int argc, char* argv[])
{
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
MK_KN_NM, // 4
MK_NK_NM, // 5
KM_KN_NM, // 6
KM_NK_NM, // 7
};
};
enum struct GemmDataType
{
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
};
};
int profile_gemm(int argc, char* argv[])
{
if(!(argc == 14 || argc == 15))
if(argc != 14)
{
printf("arg1: tensor operation (gemm: GEMM)\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
......@@ -41,8 +39,7 @@ int profile_gemm(int argc, char* argv[])
printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: split k into mulitiple batch\n");
exit(1);
return false;
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
......@@ -59,13 +56,10 @@ int profile_gemm(int argc, char* argv[])
const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]);
int KBatch = 1;
if(argc == 15)
KBatch = std::stoi(argv[14]);
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm_impl<ck::half_t,
return ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
......@@ -80,12 +74,11 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
(StrideC < 0) ? N : StrideC);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_gemm_impl<ck::half_t,
return ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
......@@ -100,12 +93,11 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
(StrideC < 0) ? N : StrideC);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_gemm_impl<ck::half_t,
return ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -120,12 +112,11 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
(StrideC < 0) ? N : StrideC);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_gemm_impl<ck::half_t,
return ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -140,12 +131,11 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
(StrideC < 0) ? N : StrideC);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm_impl<float,
return ck::profiler::profile_gemm_impl<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
......@@ -160,12 +150,11 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
(StrideC < 0) ? N : StrideC);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_gemm_impl<float,
return ck::profiler::profile_gemm_impl<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
......@@ -180,12 +169,11 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
(StrideC < 0) ? N : StrideC);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_gemm_impl<float,
return ck::profiler::profile_gemm_impl<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -200,12 +188,11 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
(StrideC < 0) ? N : StrideC);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_gemm_impl<float,
return ck::profiler::profile_gemm_impl<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -220,12 +207,11 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
(StrideC < 0) ? N : StrideC);
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm_impl<int8_t,
return ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
ck::tensor_layout::gemm::RowMajor,
......@@ -240,12 +226,11 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
(StrideC < 0) ? N : StrideC);
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_gemm_impl<int8_t,
return ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
ck::tensor_layout::gemm::RowMajor,
......@@ -260,12 +245,11 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
(StrideC < 0) ? N : StrideC);
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_gemm_impl<int8_t,
return ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -280,12 +264,11 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
(StrideC < 0) ? N : StrideC);
}
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_gemm_impl<int8_t,
return ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -300,12 +283,11 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
(StrideC < 0) ? N : StrideC);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm_impl<ck::bhalf_t,
return ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
ck::tensor_layout::gemm::RowMajor,
......@@ -320,12 +302,11 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
(StrideC < 0) ? N : StrideC);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_gemm_impl<ck::bhalf_t,
return ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
ck::tensor_layout::gemm::RowMajor,
......@@ -340,12 +321,11 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
(StrideC < 0) ? N : StrideC);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_gemm_impl<ck::bhalf_t,
return ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -360,12 +340,11 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
(StrideC < 0) ? N : StrideC);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_gemm_impl<ck::bhalf_t,
return ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -380,13 +359,12 @@ int profile_gemm(int argc, char* argv[])
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
(StrideC < 0) ? N : StrideC);
}
else
{
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
return true;
}
}
......@@ -6,8 +6,10 @@
#include <half.hpp>
#include "profile_gemm_bias_2d_impl.hpp"
enum struct GemmMatrixLayout
bool profile_gemm_bias_2d(int argc, char* argv[])
{
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
......@@ -16,17 +18,15 @@ enum struct GemmMatrixLayout
MK_NK_NM, // 5
KM_KN_NM, // 6
KM_NK_NM, // 7
};
};
enum struct GemmDataType
{
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
};
};
int profile_gemm_bias_2d(int argc, char* argv[])
{
if(!(argc == 16 || argc == 17))
if(argc != 16)
{
printf("arg1: tensor operation (gemm: GEMM+Bias_2d)\n");
printf("arg2: data type (0: fp32; 1: fp16)\n");
......@@ -41,8 +41,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: alpha\n");
printf("arg15: beta\n");
printf("arg16: split k into mulitiple batch\n");
exit(1);
return false;
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
......@@ -65,7 +64,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm_bias_2d_impl<float,
return ck::profiler::profile_gemm_bias_2d_impl<float,
float,
float,
float,
......@@ -88,7 +87,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_gemm_bias_2d_impl<float,
return ck::profiler::profile_gemm_bias_2d_impl<float,
float,
float,
float,
......@@ -111,7 +110,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_gemm_bias_2d_impl<float,
return ck::profiler::profile_gemm_bias_2d_impl<float,
float,
float,
float,
......@@ -134,7 +133,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_gemm_bias_2d_impl<float,
return ck::profiler::profile_gemm_bias_2d_impl<float,
float,
float,
float,
......@@ -157,7 +156,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm_bias_2d_impl<ck::half_t,
return ck::profiler::profile_gemm_bias_2d_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
......@@ -180,7 +179,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_gemm_bias_2d_impl<ck::half_t,
return ck::profiler::profile_gemm_bias_2d_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
......@@ -203,7 +202,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_gemm_bias_2d_impl<ck::half_t,
return ck::profiler::profile_gemm_bias_2d_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
......@@ -226,7 +225,7 @@ int profile_gemm_bias_2d(int argc, char* argv[])
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_gemm_bias_2d_impl<ck::half_t,
return ck::profiler::profile_gemm_bias_2d_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
......@@ -249,8 +248,8 @@ int profile_gemm_bias_2d(int argc, char* argv[])
}
else
{
throw std::runtime_error("wrong! this data_type & layout is not implemented");
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
return true;
}
}
......@@ -6,8 +6,10 @@
#include <half.hpp>
#include "profile_gemm_bias_relu_impl.hpp"
enum struct GemmMatrixLayout
bool profile_gemm_bias_relu(int argc, char* argv[])
{
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
......@@ -16,16 +18,14 @@ enum struct GemmMatrixLayout
MK_NK_NM, // 5
KM_KN_NM, // 6
KM_NK_NM, // 7
};
};
enum struct GemmDataType
{
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
};
};
int profile_gemm_bias_relu(int argc, char* argv[])
{
if(!(argc == 14 || argc == 15))
{
printf("arg1: tensor operation (gemm: GEMM+Bias+ReLU)\n");
......@@ -40,7 +40,7 @@ int profile_gemm_bias_relu(int argc, char* argv[])
printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: split k into mulitiple batch\n");
exit(1);
return false;
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
......@@ -60,7 +60,7 @@ int profile_gemm_bias_relu(int argc, char* argv[])
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm_bias_relu_impl<ck::half_t,
return ck::profiler::profile_gemm_bias_relu_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
......@@ -79,7 +79,7 @@ int profile_gemm_bias_relu(int argc, char* argv[])
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_gemm_bias_relu_impl<ck::half_t,
return ck::profiler::profile_gemm_bias_relu_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
......@@ -98,7 +98,7 @@ int profile_gemm_bias_relu(int argc, char* argv[])
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_gemm_bias_relu_impl<ck::half_t,
return ck::profiler::profile_gemm_bias_relu_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -117,7 +117,7 @@ int profile_gemm_bias_relu(int argc, char* argv[])
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_gemm_bias_relu_impl<ck::half_t,
return ck::profiler::profile_gemm_bias_relu_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -136,8 +136,8 @@ int profile_gemm_bias_relu(int argc, char* argv[])
}
else
{
throw std::runtime_error("wrong! this data_type & layout is not implemented");
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
return true;
}
}
......@@ -6,8 +6,10 @@
#include <half.hpp>
#include "profile_gemm_bias_relu_add_impl.hpp"
enum struct GemmMatrixLayout
bool profile_gemm_bias_relu_add(int argc, char* argv[])
{
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
......@@ -16,17 +18,15 @@ enum struct GemmMatrixLayout
MK_NK_NM, // 5
KM_KN_NM, // 6
KM_NK_NM, // 7
};
};
enum struct GemmDataType
{
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
};
};
int profile_gemm_bias_relu_add(int argc, char* argv[])
{
if(!(argc == 15 || argc == 16))
if(argc != 15)
{
printf("arg1: tensor operation (gemm: GEMM+Bias+ReLU+Add)\n");
printf("arg2: data type (0: fp32; 1: fp16)\n");
......@@ -39,8 +39,7 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, StrideC1\n");
printf("arg15: split k into mulitiple batch\n");
exit(1);
return false;
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
......@@ -61,7 +60,7 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm_bias_relu_add_impl<ck::half_t,
return ck::profiler::profile_gemm_bias_relu_add_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
......@@ -81,7 +80,7 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_gemm_bias_relu_add_impl<ck::half_t,
return ck::profiler::profile_gemm_bias_relu_add_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
......@@ -101,7 +100,7 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_gemm_bias_relu_add_impl<ck::half_t,
return ck::profiler::profile_gemm_bias_relu_add_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -121,7 +120,7 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_gemm_bias_relu_add_impl<ck::half_t,
return ck::profiler::profile_gemm_bias_relu_add_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -141,8 +140,8 @@ int profile_gemm_bias_relu_add(int argc, char* argv[])
}
else
{
throw std::runtime_error("wrong! this data_type & layout is not implemented");
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
return true;
}
}
......@@ -6,7 +6,8 @@
#include <half.hpp>
#include "profile_gemm_reduce_impl.hpp"
int profile_gemm_reduce(int argc, char* argv[])
// returns true if the test passes
bool profile_gemm_reduce(int argc, char* argv[])
{
enum struct GemmMatrixLayout
{
......@@ -22,7 +23,7 @@ int profile_gemm_reduce(int argc, char* argv[])
F16_F16_F16_F32_F32, // 1
};
if(!(argc == 14 || argc == 15))
if(argc != 14)
{
printf("arg1: tensor operation (gemm: GEMM+Reduce)\n");
printf("arg2: data type (0: fp32; 1: fp16)\n");
......@@ -34,9 +35,7 @@ int profile_gemm_reduce(int argc, char* argv[])
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: split k into mulitiple batch\n");
exit(1);
return false;
}
const auto data_type = static_cast<GemmReduceDataType>(std::stoi(argv[2]));
......@@ -56,7 +55,7 @@ int profile_gemm_reduce(int argc, char* argv[])
if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm_reduce_impl<ck::half_t,
return ck::profiler::profile_gemm_reduce_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
......@@ -77,7 +76,7 @@ int profile_gemm_reduce(int argc, char* argv[])
else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_gemm_reduce_impl<ck::half_t,
return ck::profiler::profile_gemm_reduce_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
......@@ -98,7 +97,7 @@ int profile_gemm_reduce(int argc, char* argv[])
else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_gemm_reduce_impl<ck::half_t,
return ck::profiler::profile_gemm_reduce_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
......@@ -119,7 +118,7 @@ int profile_gemm_reduce(int argc, char* argv[])
else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_gemm_reduce_impl<ck::half_t,
return ck::profiler::profile_gemm_reduce_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
......@@ -139,8 +138,8 @@ int profile_gemm_reduce(int argc, char* argv[])
}
else
{
throw std::runtime_error("wrong! this data_type & layout is not implemented");
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
return true;
}
}
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <half.hpp>
#include "profile_gemm_splitk_impl.hpp"
// return true if test passes
bool profile_gemm_splitk(int argc, char* argv[])
{
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
};
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
};
if(argc != 15)
{
printf("arg1: tensor operation (gemm: GEMM)\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg8: print tensor value (0: no; 1: yes)\n");
printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: split k into mulitiple batch\n");
return false;
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const int nrepeat = std::stoi(argv[7]);
const int M = std::stoi(argv[8]);
const int N = std::stoi(argv[9]);
const int K = std::stoi(argv[10]);
const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]);
const int KBatch = std::stoi(argv[14]);
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return ck::profiler::profile_gemm_splitk_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return ck::profiler::profile_gemm_splitk_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
return ck::profiler::profile_gemm_splitk_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
return ck::profiler::profile_gemm_splitk_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
return ck::profiler::profile_gemm_splitk_impl<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{
return ck::profiler::profile_gemm_splitk_impl<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? K : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{
return ck::profiler::profile_gemm_splitk_impl<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? N : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
return ck::profiler::profile_gemm_splitk_impl<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
M,
N,
K,
(StrideA < 0) ? M : StrideA,
(StrideB < 0) ? K : StrideB,
(StrideC < 0) ? N : StrideC,
KBatch);
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return true;
}
}
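Every branch above applies the same defaulting rule: a negative stride argument is replaced by the packed leading dimension implied by the layout (K for row-major A[m, k], M for column-major A[k, m], and so on). A sketch of that rule factored into a helper (the helper is ours; the commit inlines the ternaries):

// A negative stride means "not specified": fall back to the packed
// leading dimension for the given layout, mirroring the inlined
// (Stride < 0) ? dim : Stride ternaries above.
//   A[m, k] row-major: lda = K     A[k, m] column-major: lda = M
//   B[k, n] row-major: ldb = N     B[n, k] column-major: ldb = K
//   C[m, n] row-major: ldc = N
int default_stride(int stride, int packed_leading_dim)
{
    return (stride < 0) ? packed_leading_dim : stride;
}

// Example: for MK_KN_MN (all row-major),
//   default_stride(StrideA, K), default_stride(StrideB, N),
//   default_stride(StrideC, N)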
......@@ -6,26 +6,6 @@
#include <half.hpp>
#include "profile_grouped_gemm_impl.hpp"
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
MK_KN_NM, // 4
MK_NK_NM, // 5
KM_KN_NM, // 6
KM_NK_NM, // 7
};
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
};
std::vector<int> argToIntArray(char* input)
{
std::vector<int> out;
......@@ -42,9 +22,25 @@ std::vector<int> argToIntArray(char* input)
return out;
}
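The body of argToIntArray is collapsed in this diff; a plausible reconstruction, assuming it splits a comma-separated list such as "256,256" into integers:

#include <cstdlib>
#include <cstring>
#include <vector>

// Plausible reconstruction of the collapsed body: tokenize on commas and
// convert each token to int. strtok mutates its input, which is acceptable
// for argv storage.
std::vector<int> argToIntArray(char* input)
{
    std::vector<int> out;

    for(char* token = std::strtok(input, ","); token != nullptr;
        token = std::strtok(nullptr, ","))
    {
        out.push_back(std::atoi(token));
    }

    return out;
}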
int profile_grouped_gemm(int argc, char* argv[])
bool profile_grouped_gemm(int argc, char* argv[])
{
if(!(argc == 14))
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
};
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
};
if(argc != 14)
{
printf("arg1: tensor operation (grouped_gemm: Grouped GEMM)\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
......@@ -58,7 +54,7 @@ int profile_grouped_gemm(int argc, char* argv[])
printf("arg7: run kernel # of times (>1)\n");
printf("arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n");
exit(1);
return false;
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
......@@ -78,12 +74,13 @@ int profile_grouped_gemm(int argc, char* argv[])
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::half_t,
return ck::profiler::profile_grouped_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
......@@ -96,12 +93,13 @@ int profile_grouped_gemm(int argc, char* argv[])
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::half_t,
return ck::profiler::profile_grouped_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
......@@ -114,12 +112,13 @@ int profile_grouped_gemm(int argc, char* argv[])
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::half_t,
return ck::profiler::profile_grouped_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
......@@ -132,12 +131,13 @@ int profile_grouped_gemm(int argc, char* argv[])
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::half_t,
return ck::profiler::profile_grouped_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
......@@ -150,8 +150,8 @@ int profile_grouped_gemm(int argc, char* argv[])
}
else
{
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
return true;
}
}
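Per the help text, each shape argument to the grouped profiler is itself a comma-separated list with one entry per group. A sketch of how the example arguments decompose (argv indices follow the help text; whether the impl consumes them in exactly this shape is not visible in the diff):

#include <vector>

std::vector<int> argToIntArray(char* input); // defined above

// Sketch: with the example from the help text
//   "256,256 128,128 64,64 64,64 64,64 128,128"
// the six lists describe two GEMM problems:
//   group 0: M=256, N=128, K=64, StrideA=64, StrideB=64, StrideC=128
//   group 1: same shape
void parse_grouped_shapes(char* argv[])
{
    std::vector<int> Ms       = argToIntArray(argv[8]);  // {256, 256}
    std::vector<int> Ns       = argToIntArray(argv[9]);  // {128, 128}
    std::vector<int> Ks       = argToIntArray(argv[10]); // {64, 64}
    std::vector<int> StrideAs = argToIntArray(argv[11]); // {64, 64}
    std::vector<int> StrideBs = argToIntArray(argv[12]); // {64, 64}
    std::vector<int> StrideCs = argToIntArray(argv[13]); // {128, 128}
}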
......@@ -320,7 +320,7 @@ class AppArgs
}; // end of class AppArgs
int profile_reduce(int argc, char* argv[])
bool profile_reduce(int argc, char* argv[])
{
using namespace ck::profiler;
......@@ -499,5 +499,5 @@ int profile_reduce(int argc, char* argv[])
throw std::runtime_error("Invalid compType assignment!");
};
return (0);
return true;
};
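Unlike the GEMM entry points, profile_reduce still throws on an invalid compType even though its return type changed to bool. If a uniform no-throw pass/fail contract is wanted at the call site, a wrapper is one option (a sketch, not part of this commit):

#include <iostream>
#include <stdexcept>

bool profile_reduce(int argc, char* argv[]); // declared above

// Sketch: adapt the still-throwing entry point to the print-and-return
// convention the other profiler functions now follow.
bool profile_reduce_noexcept(int argc, char* argv[])
{
    try
    {
        return profile_reduce(argc, argv);
    }
    catch(const std::exception& e)
    {
        std::cout << e.what() << std::endl;
        return false;
    }
}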