Commit 0e67221f authored by Chao Liu's avatar Chao Liu
Browse files

format

parent fe027ba3
......@@ -21,18 +21,19 @@ enum GemmMatrixLayout
KM_KN_MN, // 2
KM_NK_MN, // 3
};
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
using GEMM_PTR = std::vector<DeviceGemmNoOpPtr>;
static std::vector<std::vector<bool>>& GetLayoutType()
{
static std::vector<std::vector<bool>> LayOut = {{0, 0, 0}, {0, 1, 0}, {1, 0, 0}, {1, 1, 0}};
return LayOut;
}
static void add_device_gemm_instance_mk_kn_mn(GEMM_PTR& gemm_ptrs)
static void add_device_gemm_instance_mk_kn_mn(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs)
{
ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
float,
......@@ -42,7 +43,8 @@ static void add_device_gemm_instance_mk_kn_mn(GEMM_PTR& gemm_ptrs)
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(gemm_ptrs);
}
static void add_device_gemm_instance_mk_nk_mn(GEMM_PTR& gemm_ptrs)
static void add_device_gemm_instance_mk_nk_mn(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs)
{
ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
float,
......@@ -52,7 +54,7 @@ static void add_device_gemm_instance_mk_nk_mn(GEMM_PTR& gemm_ptrs)
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(gemm_ptrs);
}
static void add_device_gemm_instance_km_kn_mn(GEMM_PTR& gemm_ptrs)
static void add_device_gemm_instance_km_kn_mn(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs)
{
ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
float,
......@@ -62,7 +64,7 @@ static void add_device_gemm_instance_km_kn_mn(GEMM_PTR& gemm_ptrs)
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(gemm_ptrs);
}
static void add_device_gemm_instance_km_nk_mn(GEMM_PTR& gemm_ptrs)
static void add_device_gemm_instance_km_nk_mn(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs)
{
ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
float,
......@@ -75,7 +77,7 @@ static void add_device_gemm_instance_km_nk_mn(GEMM_PTR& gemm_ptrs)
static auto& GetAddDeviceGemmInstance()
{
static std::vector<void (*)(GEMM_PTR&)> AddDeviceGemmInstance = {
static std::vector<void (*)(std::vector<DeviceGemmNoOpPtr>&)> AddDeviceGemmInstance = {
add_device_gemm_instance_mk_kn_mn,
add_device_gemm_instance_mk_nk_mn,
add_device_gemm_instance_km_kn_mn,
......@@ -83,7 +85,7 @@ static auto& GetAddDeviceGemmInstance()
return AddDeviceGemmInstance;
}
static void add_device_gemm_instance(GEMM_PTR& gemm_ptrs, int layout)
static void add_device_gemm_instance(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs, int layout)
{
GetAddDeviceGemmInstance()[layout](gemm_ptrs);
}
......@@ -104,6 +106,7 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
return true;
}
int main(int argc, char* argv[])
{
if(argc != 9)
......@@ -175,7 +178,7 @@ int main(int argc, char* argv[])
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
// add device GEMM instances
GEMM_PTR gemm_ptrs;
std::vector<DeviceGemmNoOpPtr> gemm_ptrs;
add_device_gemm_instance(gemm_ptrs, layout);
bool success = false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment