Commit 8160c31a authored by Chao Liu
Browse files

clean up

parent 0e67221f
#pragma once
#include "device_gemm_instance.hpp"
#include "device_gemm_splitk_xdl_instance.hpp"
//#include "device_gemm_instance.hpp"
//#include "device_gemm_splitk_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
#if 0
template <>
void add_device_gemm_instance<float,
float,
......@@ -70,6 +71,22 @@ void add_device_gemm_instance<ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
#else
// Instance-registration hooks: each function appends a set of pre-instantiated
// XDL device GEMM kernels for one data type (f16/f32) and one layout
// combination to the caller's list.  The mk/km and kn/nk name tokens
// presumably encode Row/Column-major layouts of A and B, with mn a Row-major
// C -- confirm against the template arguments in the #if 0 branch above.
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
// Split-K variants of the f32 factories (same layout naming scheme).
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
#endif
} // namespace device_gemm_instance
} // namespace device
......
......@@ -11,15 +11,13 @@ include_directories(BEFORE
${PROJECT_SOURCE_DIR}/external/rocm/include
)
# test_magic_number_division
# Fix: variable name was misspelled MAGIC_NUMBER_DIVISISON_SOURCE; the typo was
# self-consistent (set and use) but misleading, so rename it properly.
set(MAGIC_NUMBER_DIVISION_SOURCE magic_number_division/main.cpp)
add_executable(test_magic_number_division ${MAGIC_NUMBER_DIVISION_SOURCE})
target_link_libraries(test_magic_number_division PRIVATE host_tensor)
# test_split_k
# Fix: diff residue left a literal "\ No newline at end of file" marker (not
# valid CMake) and a duplicated device_gemm_instance link line; keep one copy
# of each link dependency.
set(SPLIT_K_SOURCE split_k/main.cpp)
add_executable(test_split_k ${SPLIT_K_SOURCE})
target_link_libraries(test_split_k PRIVATE host_tensor)
target_link_libraries(test_split_k PRIVATE device_gemm_instance)
......@@ -8,11 +8,9 @@
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_gemm_instance.hpp"
#include "host_gemm.hpp"
#include "tensor_layout.hpp"
#include "device_gemm_splitk_xdl_instance.hpp"
#include "device_gemm_splitk_xdl.hpp"
#include "device_gemm_xdl_splitk.hpp"
enum GemmMatrixLayout
{
......@@ -33,6 +31,7 @@ static std::vector<std::vector<bool>>& GetLayoutType()
return LayOut;
}
#if 0
static void add_device_gemm_instance_mk_kn_mn(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs)
{
ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
......@@ -84,10 +83,23 @@ static auto& GetAddDeviceGemmInstance()
add_device_gemm_instance_km_nk_mn};
return AddDeviceGemmInstance;
}
#else
// Forward declarations of the split-K f32 instance factories (presumably
// defined in the device_gemm_instance library this test links against --
// verify in CMake).  Declared locally rather than via the instance header;
// the retired template-based API lives in the #if 0 block above.
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
#endif
// Populate gemm_ptrs with the device GEMM instances for the requested layout.
//
// @param gemm_ptrs  output list the instance factories append to
// @param layout     index into the layout table (see GemmMatrixLayout /
//                   GetLayoutType(); presumably 2 == km_kn_mn -- TODO confirm
//                   the mapping for the other indices and wire up their
//                   instances)
//
// Fix: for any layout other than 2 the old code silently left gemm_ptrs
// empty, so the caller's loop over instances could report a vacuous "Pass".
// Surface the unsupported case on stderr instead of swallowing it; the
// registered behavior for layout == 2 is unchanged.
static void add_device_gemm_instance(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs, int layout)
{
#if 0
    GetAddDeviceGemmInstance()[layout](gemm_ptrs);
#else
    if(layout == 2)
    {
        add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
    }
    else
    {
        std::cerr << "add_device_gemm_instance: unsupported layout " << layout
                  << " (only layout 2 has registered instances)" << std::endl;
    }
#endif
}
template <typename T>
......@@ -150,6 +162,7 @@ int main(int argc, char* argv[])
std::vector<std::size_t>({stride, 1}));
}
};
Tensor<float> a_m_k(f_host_tensor_descriptor(M, K, StrideA, LayOut[layout][0]));
Tensor<float> b_k_n(f_host_tensor_descriptor(K, N, StrideB, LayOut[layout][1]));
Tensor<float> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, LayOut[layout][2]));
......@@ -213,6 +226,7 @@ int main(int argc, char* argv[])
success = true;
}
}
if(success)
{
std::cout << "test split k : Pass" << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment