Commit 0e67221f authored by Chao Liu's avatar Chao Liu
Browse files

format

parent fe027ba3
...@@ -21,18 +21,19 @@ enum GemmMatrixLayout ...@@ -21,18 +21,19 @@ enum GemmMatrixLayout
KM_KN_MN, // 2 KM_KN_MN, // 2
KM_NK_MN, // 3 KM_NK_MN, // 3
}; };
using DeviceGemmNoOpPtr = using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>; ck::tensor_operation::element_wise::PassThrough>;
using GEMM_PTR = std::vector<DeviceGemmNoOpPtr>;
static std::vector<std::vector<bool>>& GetLayoutType() static std::vector<std::vector<bool>>& GetLayoutType()
{ {
static std::vector<std::vector<bool>> LayOut = {{0, 0, 0}, {0, 1, 0}, {1, 0, 0}, {1, 1, 0}}; static std::vector<std::vector<bool>> LayOut = {{0, 0, 0}, {0, 1, 0}, {1, 0, 0}, {1, 1, 0}};
return LayOut; return LayOut;
} }
static void add_device_gemm_instance_mk_kn_mn(GEMM_PTR& gemm_ptrs)
static void add_device_gemm_instance_mk_kn_mn(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs)
{ {
ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance< ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
float, float,
...@@ -42,7 +43,8 @@ static void add_device_gemm_instance_mk_kn_mn(GEMM_PTR& gemm_ptrs) ...@@ -42,7 +43,8 @@ static void add_device_gemm_instance_mk_kn_mn(GEMM_PTR& gemm_ptrs)
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(gemm_ptrs); ck::tensor_layout::gemm::RowMajor>(gemm_ptrs);
} }
static void add_device_gemm_instance_mk_nk_mn(GEMM_PTR& gemm_ptrs)
static void add_device_gemm_instance_mk_nk_mn(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs)
{ {
ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance< ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
float, float,
...@@ -52,7 +54,7 @@ static void add_device_gemm_instance_mk_nk_mn(GEMM_PTR& gemm_ptrs) ...@@ -52,7 +54,7 @@ static void add_device_gemm_instance_mk_nk_mn(GEMM_PTR& gemm_ptrs)
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(gemm_ptrs); ck::tensor_layout::gemm::RowMajor>(gemm_ptrs);
} }
static void add_device_gemm_instance_km_kn_mn(GEMM_PTR& gemm_ptrs) static void add_device_gemm_instance_km_kn_mn(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs)
{ {
ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance< ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
float, float,
...@@ -62,7 +64,7 @@ static void add_device_gemm_instance_km_kn_mn(GEMM_PTR& gemm_ptrs) ...@@ -62,7 +64,7 @@ static void add_device_gemm_instance_km_kn_mn(GEMM_PTR& gemm_ptrs)
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(gemm_ptrs); ck::tensor_layout::gemm::RowMajor>(gemm_ptrs);
} }
static void add_device_gemm_instance_km_nk_mn(GEMM_PTR& gemm_ptrs) static void add_device_gemm_instance_km_nk_mn(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs)
{ {
ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance< ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
float, float,
...@@ -75,7 +77,7 @@ static void add_device_gemm_instance_km_nk_mn(GEMM_PTR& gemm_ptrs) ...@@ -75,7 +77,7 @@ static void add_device_gemm_instance_km_nk_mn(GEMM_PTR& gemm_ptrs)
static auto& GetAddDeviceGemmInstance() static auto& GetAddDeviceGemmInstance()
{ {
static std::vector<void (*)(GEMM_PTR&)> AddDeviceGemmInstance = { static std::vector<void (*)(std::vector<DeviceGemmNoOpPtr>&)> AddDeviceGemmInstance = {
add_device_gemm_instance_mk_kn_mn, add_device_gemm_instance_mk_kn_mn,
add_device_gemm_instance_mk_nk_mn, add_device_gemm_instance_mk_nk_mn,
add_device_gemm_instance_km_kn_mn, add_device_gemm_instance_km_kn_mn,
...@@ -83,7 +85,7 @@ static auto& GetAddDeviceGemmInstance() ...@@ -83,7 +85,7 @@ static auto& GetAddDeviceGemmInstance()
return AddDeviceGemmInstance; return AddDeviceGemmInstance;
} }
static void add_device_gemm_instance(GEMM_PTR& gemm_ptrs, int layout) static void add_device_gemm_instance(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs, int layout)
{ {
GetAddDeviceGemmInstance()[layout](gemm_ptrs); GetAddDeviceGemmInstance()[layout](gemm_ptrs);
} }
...@@ -104,6 +106,7 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result) ...@@ -104,6 +106,7 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
return true; return true;
} }
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
if(argc != 9) if(argc != 9)
...@@ -175,7 +178,7 @@ int main(int argc, char* argv[]) ...@@ -175,7 +178,7 @@ int main(int argc, char* argv[])
c_device_buf.ToDevice(c_m_n_device_result.mData.data()); c_device_buf.ToDevice(c_m_n_device_result.mData.data());
// add device GEMM instances // add device GEMM instances
GEMM_PTR gemm_ptrs; std::vector<DeviceGemmNoOpPtr> gemm_ptrs;
add_device_gemm_instance(gemm_ptrs, layout); add_device_gemm_instance(gemm_ptrs, layout);
bool success = false; bool success = false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment