Commit 8160c31a authored by Chao Liu's avatar Chao Liu
Browse files

clean up

parent 0e67221f
#pragma once #pragma once
#include "device_gemm_instance.hpp" //#include "device_gemm_instance.hpp"
#include "device_gemm_splitk_xdl_instance.hpp" //#include "device_gemm_splitk_xdl_instance.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_gemm_instance { namespace device_gemm_instance {
#if 0
template <> template <>
void add_device_gemm_instance<float, void add_device_gemm_instance<float,
float, float,
...@@ -70,6 +71,22 @@ void add_device_gemm_instance<ck::half_t, ...@@ -70,6 +71,22 @@ void add_device_gemm_instance<ck::half_t,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&); ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
#else
// Non-template declarations used in place of the template specializations that
// are disabled by the `#if 0` branch above. Every function here returns void
// and takes a std::vector<DeviceGemmNoOpPtr>& as its only parameter.
//
// Reading the identifiers: `xdl` vs. `xdl_splitk`, then three element-type tags
// (f16 or f32), then a layout suffix (mk_kn_mn, mk_nk_mn, km_kn_mn, km_nk_mn).
// Split-K declarations appear only for f32 in this list.
//
// NOTE(review): the definitions are not visible here; presumably each function
// appends device GEMM instances to the vector -- confirm against the defining
// translation units.
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
#endif
} // namespace device_gemm_instance } // namespace device_gemm_instance
} // namespace device } // namespace device
......
...@@ -11,15 +11,13 @@ include_directories(BEFORE ...@@ -11,15 +11,13 @@ include_directories(BEFORE
${PROJECT_SOURCE_DIR}/external/rocm/include ${PROJECT_SOURCE_DIR}/external/rocm/include
) )
# test_magic_number_division
set(MAGIC_NUMBER_DIVISISON_SOURCE magic_number_division/main.cpp) set(MAGIC_NUMBER_DIVISISON_SOURCE magic_number_division/main.cpp)
add_executable(test_magic_number_division ${MAGIC_NUMBER_DIVISISON_SOURCE}) add_executable(test_magic_number_division ${MAGIC_NUMBER_DIVISISON_SOURCE})
target_link_libraries(test_magic_number_division PRIVATE host_tensor) target_link_libraries(test_magic_number_division PRIVATE host_tensor)
# test_split_k
set(SPLIT_K_SOURCE split_k/main.cpp) set(SPLIT_K_SOURCE split_k/main.cpp)
add_executable(test_split_k ${SPLIT_K_SOURCE}) add_executable(test_split_k ${SPLIT_K_SOURCE})
target_link_libraries(test_split_k PRIVATE host_tensor) target_link_libraries(test_split_k PRIVATE host_tensor)
target_link_libraries(test_split_k PRIVATE device_gemm_instance) target_link_libraries(test_split_k PRIVATE device_gemm_instance)
\ No newline at end of file
...@@ -8,11 +8,9 @@ ...@@ -8,11 +8,9 @@
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_gemm_instance.hpp"
#include "host_gemm.hpp" #include "host_gemm.hpp"
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "device_gemm_splitk_xdl_instance.hpp" #include "device_gemm_xdl_splitk.hpp"
#include "device_gemm_splitk_xdl.hpp"
enum GemmMatrixLayout enum GemmMatrixLayout
{ {
...@@ -33,6 +31,7 @@ static std::vector<std::vector<bool>>& GetLayoutType() ...@@ -33,6 +31,7 @@ static std::vector<std::vector<bool>>& GetLayoutType()
return LayOut; return LayOut;
} }
#if 0
static void add_device_gemm_instance_mk_kn_mn(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs) static void add_device_gemm_instance_mk_kn_mn(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs)
{ {
ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance< ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
...@@ -84,10 +83,23 @@ static auto& GetAddDeviceGemmInstance() ...@@ -84,10 +83,23 @@ static auto& GetAddDeviceGemmInstance()
add_device_gemm_instance_km_nk_mn}; add_device_gemm_instance_km_nk_mn};
return AddDeviceGemmInstance; return AddDeviceGemmInstance;
} }
#else
// Local forward declarations of the four f32 split-K entry points, one per
// layout suffix (mk_kn_mn, mk_nk_mn, km_kn_mn, km_nk_mn). They stand in for the
// helper functions disabled by the `#if 0` branch above.
//
// NOTE(review): the header portion of this diff declares functions with these
// same names inside ck::tensor_operation::device::device_gemm_instance. No
// enclosing namespace is visible around the declarations below -- confirm they
// resolve to the same symbols as the definitions, otherwise this will fail at
// link time.
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
#endif
// Fills `gemm_ptrs` with the device GEMM instances selected by `layout`.
//
// gemm_ptrs: vector handed on to the instance-registration function.
// layout:    integer layout selector; in the active (#else) branch only the
//            value 2 is acted on.
//
// NOTE(review): for any `layout` other than 2 the active branch does nothing
// and reports nothing, so the caller is left with whatever `gemm_ptrs` already
// held. The mk_kn_mn, mk_nk_mn and km_nk_mn split-K functions are declared in
// this file but never called from here -- confirm this restriction is
// intentional and not leftover debugging.
static void add_device_gemm_instance(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs, int layout) static void add_device_gemm_instance(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs, int layout)
{ {
// Disabled path: table lookup indexed by `layout` via GetAddDeviceGemmInstance().
#if 0
GetAddDeviceGemmInstance()[layout](gemm_ptrs); GetAddDeviceGemmInstance()[layout](gemm_ptrs);
#else
// Active path: only layout 2 is wired up, to the km_kn_mn split-K function.
if(layout == 2)
{
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
}
#endif
} }
template <typename T> template <typename T>
...@@ -150,6 +162,7 @@ int main(int argc, char* argv[]) ...@@ -150,6 +162,7 @@ int main(int argc, char* argv[])
std::vector<std::size_t>({stride, 1})); std::vector<std::size_t>({stride, 1}));
} }
}; };
Tensor<float> a_m_k(f_host_tensor_descriptor(M, K, StrideA, LayOut[layout][0])); Tensor<float> a_m_k(f_host_tensor_descriptor(M, K, StrideA, LayOut[layout][0]));
Tensor<float> b_k_n(f_host_tensor_descriptor(K, N, StrideB, LayOut[layout][1])); Tensor<float> b_k_n(f_host_tensor_descriptor(K, N, StrideB, LayOut[layout][1]));
Tensor<float> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, LayOut[layout][2])); Tensor<float> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, LayOut[layout][2]));
...@@ -213,6 +226,7 @@ int main(int argc, char* argv[]) ...@@ -213,6 +226,7 @@ int main(int argc, char* argv[])
success = true; success = true;
} }
} }
if(success) if(success)
{ {
std::cout << "test split k : Pass" << std::endl; std::cout << "test split k : Pass" << std::endl;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment