Commit 14bd1430 authored by myamlak's avatar myamlak
Browse files

Sketch of tests

parent 6a0883fa
@@ -60,3 +60,4 @@ add_subdirectory(grouped_gemm)
add_subdirectory(convnd_fwd)
add_subdirectory(reduce)
add_subdirectory(conv2d_bwd_weight)
add_subdirectory(cgemm)
# CGEMM unit tests: one executable per data type (fp32 / fp16 / bf16).
# Each test links the host_tensor helpers (reference/host-side tensors)
# and the pre-built device_cgemm_instance kernel library.
add_test_executable(test_cgemm_fp32 cgemm_fp32.cpp)
target_link_libraries(test_cgemm_fp32 PRIVATE host_tensor)
target_link_libraries(test_cgemm_fp32 PRIVATE device_cgemm_instance)
add_test_executable(test_cgemm_fp16 cgemm_fp16.cpp)
target_link_libraries(test_cgemm_fp16 PRIVATE host_tensor)
target_link_libraries(test_cgemm_fp16 PRIVATE device_cgemm_instance)
add_test_executable(test_cgemm_bf16 cgemm_bf16.cpp)
target_link_libraries(test_cgemm_bf16 PRIVATE host_tensor)
target_link_libraries(test_cgemm_bf16 PRIVATE device_cgemm_instance)
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "cgemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_cgemm_4gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_cgemm.hpp"
#include "gemm_specialization.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Type-erased pointer to a CGEMM instance with pass-through element-wise
// operators on A, B and C.
// NOTE(review): original sketch wrote "DeviceGemmPtr" here while the fp16/fp32
// sketches wrote "DevicecgemmPtr"; normalized to DeviceCGemmPtr to match the
// CGEMM naming used throughout (DeviceCGemmNoOpPtr, device_cgemm_instance) --
// confirm against device_cgemm_4gemm_xdl_cshuffle.hpp.
using DeviceCGemmNoOpPtr =
    ck::tensor_operation::device::DeviceCGemmPtr<ck::tensor_operation::element_wise::PassThrough,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_cgemm_instance {
// Factories defined in the device_cgemm_instance library; each appends all
// pre-built bf16 4-gemm XDL c-shuffle CGEMM kernels for one layout combination
// to the output vector. Layout suffix encoding (as used by main() below):
//   km = A column-major, mk = A row-major,
//   kn = B row-major,    nk = B column-major,  mn = C row-major.
void add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
std::vector<DeviceCGemmNoOpPtr>&);
void add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
std::vector<DeviceCGemmNoOpPtr>&);
void add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<DeviceCGemmNoOpPtr>&);
void add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<DeviceCGemmNoOpPtr>&);
} // namespace device_cgemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
int main()
{
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
bool res = true;
std::vector<DeviceCGemmNoOpPtr> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(cgemmPtrs);
for(auto& cgemmPtr : cgemmPtrs)
{
res &= ck::cgemm_util::TestCGemmBF16<DeviceCGemmNoOpPtr,
ColumnMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(cgemmPtr);
}
cgemmPtrs.clear();
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(cgemmPtrs);
for(auto& cgemmPtr : cgemmPtrs)
{
res &= ck::cgemm_util::TestCGemmBF16<DeviceCGemmNoOpPtr,
ColumnMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
cgemmPtrs.clear();
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(cgemmPtrs);
for(auto& cgemmPtr : cgemmPtrs)
{
res &= ck::cgemm_util::TestCGemmBF16<DeviceCGemmNoOpPtr,
RowMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(cgemmPtr);
}
cgemmPtrs.clear();
ck::tensor_operation::device::device_cgemm_instance::
add_device_cgemm_4gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(cgemmPtrs);
for(auto& cgemmPtr : cgemmPtrs)
{
res &= ck::cgemm_util::TestCGemmBF16<DeviceCGemmNoOpPtr,
RowMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(cgemmPtr);
}
std::cout << "TestCGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res ? 0 : 1;
}
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "cgemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "device_tensor.hpp"
#include "device_cgemm_4gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "gemm_specialization.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Type-erased pointer to a CGEMM instance with pass-through element-wise
// operators on A, B and C.
// Fixed casing: sketch wrote "DevicecgemmPtr"; normalized to DeviceCGemmPtr
// to match the CGEMM naming used throughout -- confirm against
// device_cgemm_4gemm_xdl_cshuffle.hpp.
using DeviceCGemmNoOpPtr =
    ck::tensor_operation::device::DeviceCGemmPtr<ck::tensor_operation::element_wise::PassThrough,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_cgemm_instance {
// Factories defined in the device_cgemm_instance library; each appends all
// pre-built fp16 4-gemm XDL c-shuffle CGEMM kernels for one layout combination
// to the output vector. Layout suffix encoding (as used by main() below):
//   km = A column-major, mk = A row-major,
//   kn = B row-major,    nk = B column-major,  mn = C row-major.
void add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceCGemmNoOpPtr>&);
void add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceCGemmNoOpPtr>&);
void add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceCGemmNoOpPtr>&);
void add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceCGemmNoOpPtr>&);
} // namespace device_cgemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// Runs every registered fp16 CGEMM instance against the host reference for
// all four (A, B) layout combinations; exits 0 only if all instances pass.
int main()
{
    using ADataType = ck::half_t;
    using BDataType = ck::half_t;
    using CDataType = ck::half_t;

    using RowMajor    = ck::tensor_layout::gemm::RowMajor;
    using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;

    bool res = true;
    std::vector<DeviceCGemmNoOpPtr> cgemmPtrs;

    // Fixed: only the *_c_shuffle_* factories are declared in this TU; the
    // plain-xdl / splitk / 2-stage variants called by the original sketch are
    // undeclared here and would not compile.

    // km_kn_mn: A column-major, B row-major, C row-major.
    ck::tensor_operation::device::device_cgemm_instance::
        add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(cgemmPtrs);
    for(auto& cgemmPtr : cgemmPtrs)
    {
        res &= ck::cgemm_util::TestCGemm<DeviceCGemmNoOpPtr,
                                         ADataType,
                                         BDataType,
                                         CDataType,
                                         ColumnMajor,
                                         RowMajor,
                                         RowMajor,
                                         PassThrough,
                                         PassThrough,
                                         PassThrough>{}(cgemmPtr);
    }

    // km_nk_mn: A column-major, B column-major, C row-major.
    cgemmPtrs.clear();
    ck::tensor_operation::device::device_cgemm_instance::
        add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(cgemmPtrs);
    for(auto& cgemmPtr : cgemmPtrs)
    {
        res &= ck::cgemm_util::TestCGemm<DeviceCGemmNoOpPtr,
                                         ADataType,
                                         BDataType,
                                         CDataType,
                                         ColumnMajor,
                                         ColumnMajor,
                                         RowMajor,
                                         PassThrough,
                                         PassThrough,
                                         PassThrough>{}(cgemmPtr);
    }

    // mk_kn_mn: A row-major, B row-major, C row-major.
    cgemmPtrs.clear();
    ck::tensor_operation::device::device_cgemm_instance::
        add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(cgemmPtrs);
    for(auto& cgemmPtr : cgemmPtrs)
    {
        res &= ck::cgemm_util::TestCGemm<DeviceCGemmNoOpPtr,
                                         ADataType,
                                         BDataType,
                                         CDataType,
                                         RowMajor,
                                         RowMajor,
                                         RowMajor,
                                         PassThrough,
                                         PassThrough,
                                         PassThrough>{}(cgemmPtr);
    }

    // mk_nk_mn: A row-major, B column-major, C row-major.
    cgemmPtrs.clear();
    ck::tensor_operation::device::device_cgemm_instance::
        add_device_cgemm_4gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(cgemmPtrs);
    for(auto& cgemmPtr : cgemmPtrs)
    {
        res &= ck::cgemm_util::TestCGemm<DeviceCGemmNoOpPtr,
                                         ADataType,
                                         BDataType,
                                         CDataType,
                                         RowMajor,
                                         ColumnMajor,
                                         RowMajor,
                                         PassThrough,
                                         PassThrough,
                                         PassThrough>{}(cgemmPtr);
    }

    std::cout << "TestCGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
    return res ? 0 : 1;
}
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "cgemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_cgemm_4gemm_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_cgemm.hpp"
#include "gemm_specialization.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Type-erased pointer to a CGEMM instance with pass-through element-wise
// operators on A, B and C.
// Fixed casing: sketch wrote "DevicecgemmPtr"; normalized to DeviceCGemmPtr
// to match the CGEMM naming used throughout -- confirm against
// device_cgemm_4gemm_xdl_cshuffle.hpp.
using DeviceCGemmNoOpPtr =
    ck::tensor_operation::device::DeviceCGemmPtr<ck::tensor_operation::element_wise::PassThrough,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_cgemm_instance {
// Factories defined in the device_cgemm_instance library; each appends all
// pre-built fp32 4-gemm XDL c-shuffle CGEMM kernels for one layout combination
// to the output vector. Layout suffix encoding (as used by main() below):
//   km = A column-major, mk = A row-major,
//   kn = B row-major,    nk = B column-major,  mn = C row-major.
void add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(
std::vector<DeviceCGemmNoOpPtr>&);
void add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(
std::vector<DeviceCGemmNoOpPtr>&);
void add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(
std::vector<DeviceCGemmNoOpPtr>&);
void add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(
std::vector<DeviceCGemmNoOpPtr>&);
} // namespace device_cgemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// Runs every registered fp32 CGEMM instance against the host reference for
// all four (A, B) layout combinations; exits 0 only if all instances pass.
int main()
{
    using ADataType = float;
    using BDataType = float;
    using CDataType = float;

    using RowMajor    = ck::tensor_layout::gemm::RowMajor;
    using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;

    bool res = true;
    std::vector<DeviceCGemmNoOpPtr> cgemmPtrs;

    // Fixed: only the *_c_shuffle_* factories are declared in this TU; the
    // plain-xdl / splitk variants called by the original sketch are undeclared
    // here and would not compile.

    // km_kn_mn: A column-major, B row-major, C row-major.
    ck::tensor_operation::device::device_cgemm_instance::
        add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(cgemmPtrs);
    for(auto& cgemmPtr : cgemmPtrs)
    {
        res &= ck::cgemm_util::TestCGemm<DeviceCGemmNoOpPtr,
                                         ADataType,
                                         BDataType,
                                         CDataType,
                                         ColumnMajor,
                                         RowMajor,
                                         RowMajor,
                                         PassThrough,
                                         PassThrough,
                                         PassThrough>{}(cgemmPtr);
    }

    // km_nk_mn: A column-major, B column-major, C row-major.
    cgemmPtrs.clear();
    ck::tensor_operation::device::device_cgemm_instance::
        add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(cgemmPtrs);
    for(auto& cgemmPtr : cgemmPtrs)
    {
        res &= ck::cgemm_util::TestCGemm<DeviceCGemmNoOpPtr,
                                         ADataType,
                                         BDataType,
                                         CDataType,
                                         ColumnMajor,
                                         ColumnMajor,
                                         RowMajor,
                                         PassThrough,
                                         PassThrough,
                                         PassThrough>{}(cgemmPtr);
    }

    // mk_kn_mn: A row-major, B row-major, C row-major.
    cgemmPtrs.clear();
    ck::tensor_operation::device::device_cgemm_instance::
        add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(cgemmPtrs);
    for(auto& cgemmPtr : cgemmPtrs)
    {
        res &= ck::cgemm_util::TestCGemm<DeviceCGemmNoOpPtr,
                                         ADataType,
                                         BDataType,
                                         CDataType,
                                         RowMajor,
                                         RowMajor,
                                         RowMajor,
                                         PassThrough,
                                         PassThrough,
                                         PassThrough>{}(cgemmPtr);
    }

    // mk_nk_mn: A row-major, B column-major, C row-major.
    cgemmPtrs.clear();
    ck::tensor_operation::device::device_cgemm_instance::
        add_device_cgemm_4gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(cgemmPtrs);
    for(auto& cgemmPtr : cgemmPtrs)
    {
        res &= ck::cgemm_util::TestCGemm<DeviceCGemmNoOpPtr,
                                         ADataType,
                                         BDataType,
                                         CDataType,
                                         RowMajor,
                                         ColumnMajor,
                                         RowMajor,
                                         PassThrough,
                                         PassThrough,
                                         PassThrough>{}(cgemmPtr);
    }

    std::cout << "TestCGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
    return res ? 0 : 1;
}
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment