Commit 8e0beb65 authored by Alan Turner's avatar Alan Turner
Browse files

Add unit tests

parent fcca3307
......@@ -249,6 +249,8 @@ include_directories(BEFORE
${HIP_INCLUDE_DIRS}
)
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
if (NOT CK_BUILD_JIT_LIB)
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
......@@ -257,7 +259,7 @@ if (NOT CK_BUILD_JIT_LIB)
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp")
file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*)
......@@ -286,7 +288,6 @@ if (NOT CK_BUILD_JIT_LIB)
add_subdirectory(example)
add_subdirectory(test)
add_subdirectory(profiler)
else()
......@@ -297,7 +298,9 @@ else()
endif()
add_subdirectory(library)
add_subdirectory(test)
#Create an interface target for the include only files and call it "composablekernels"
include(CMakePackageConfigHelpers)
......
......@@ -49,7 +49,7 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
const bool quantize = ADataType == DataType::Int8 and BDataType == DataType::Int8;
if (get_xdlop_archs().find(arch) != get_xdlop_archs().end())
{
instance::gemm_add_add_fastgelu_instances all_instances{};
ck::tensor_operation::device::instance::gemm_add_add_fastgelu_instances all_instances{};
if(TransA and TransB)
instances = all_instances.get_col_col_instances(quantize);
else if(TransA and not TransB)
......@@ -139,7 +139,7 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
std::string Problem::GetIncludeHeader() const
{
return instance::gemm_add_add_fastgelu_instances{}.get_include_header();
return ck::tensor_operation::device::instance::gemm_add_add_fastgelu_instances{}.get_include_header();
}
std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
......
......@@ -31,33 +31,37 @@ function(add_gtest_executable TEST_NAME)
rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
endfunction(add_gtest_executable TEST_NAME)
add_subdirectory(magic_number_division)
add_subdirectory(space_filling_curve)
add_subdirectory(conv_util)
add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm)
add_subdirectory(gemm_layernorm)
add_subdirectory(gemm_split_k)
add_subdirectory(gemm_reduce)
add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)
add_subdirectory(batched_gemm_softmax_gemm)
add_subdirectory(batched_gemm_softmax_gemm_permute)
add_subdirectory(grouped_gemm)
add_subdirectory(reduce)
add_subdirectory(convnd_fwd)
add_subdirectory(convnd_bwd_data)
add_subdirectory(grouped_convnd_fwd)
add_subdirectory(grouped_convnd_bwd_weight)
add_subdirectory(block_to_ctile_map)
add_subdirectory(softmax)
add_subdirectory(normalization)
add_subdirectory(data_type)
add_subdirectory(elementwise_normalization)
add_subdirectory(batchnorm)
add_subdirectory(contraction)
add_subdirectory(pool_fwd)
if(CK_BUILD_JIT_LIB)
add_subdirectory(jit_library)
else()
add_subdirectory(magic_number_division)
add_subdirectory(space_filling_curve)
add_subdirectory(conv_util)
add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm)
add_subdirectory(gemm_layernorm)
add_subdirectory(gemm_split_k)
add_subdirectory(gemm_reduce)
add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)
add_subdirectory(batched_gemm_softmax_gemm)
add_subdirectory(batched_gemm_softmax_gemm_permute)
add_subdirectory(grouped_gemm)
add_subdirectory(reduce)
add_subdirectory(convnd_fwd)
add_subdirectory(convnd_bwd_data)
add_subdirectory(grouped_convnd_fwd)
add_subdirectory(grouped_convnd_bwd_weight)
add_subdirectory(block_to_ctile_map)
add_subdirectory(softmax)
add_subdirectory(normalization)
add_subdirectory(data_type)
add_subdirectory(elementwise_normalization)
add_subdirectory(batchnorm)
add_subdirectory(contraction)
add_subdirectory(pool_fwd)
endif()
if(GPU_TARGETS MATCHES "gfx1100")
add_subdirectory(wmma_op)
endif()
add_test_executable(test_jit_library jit_library.cpp)
add_dependencies(test_jit_library jit_library)
target_link_libraries(test_jit_library PRIVATE jit_library)
#include "ck/host/device_gemm_multiple_d.hpp"
#include <iostream>
bool test_Problem()
{
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
false,
true,
false,
{},
ck::host::DataType::Half,
ck::host::DataType::Half,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
const auto include_header = problem.GetIncludeHeader();
const auto solutions = problem.GetSolutions("gfx90a");
const auto& solution = solutions.at(0);
const auto template_str = solution.template_str;
const auto grid_size = solution.grid_size;
const auto block_size = solution.block_size;
bool pass = true;
pass &= include_header == "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp";
pass &= solutions.size() == 42;
pass &= template_str == "ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle< ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor, ck::half_t, ck::half_t, float, float, ck::Tuple<>, ck::half_t, ck::tensor_operation::element_wise::Passthrough, ck::tensor_operation::element_wise::Passthrough, ck::tensor_operation::element_wise::Passthrough, ck::tensor_operation::device::GemmSpecialization::Default, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, 1, 1, ck::Sequence<1,32,1,8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v1>";
pass &= grid_size == 2;
pass &= block_size == 256;
return pass;
}
bool test_GetGemmSpec()
{
bool pass = true;
{
//PadMNK
auto problem = ck::host::device_gemm_multiple_d::Problem{255,
255,
255,
false,
true,
false,
{},
ck::host::DataType::Half,
ck::host::DataType::Half,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
const auto solutions = problem.GetSolutions("gfx90a");
const auto& solution = solutions.at(0);
const auto template_str = solution.template_str;
pass &= template_str.find("GemmSpecialization::MNKPadding") != std::string::npos;
}
{
//Default
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
false,
true,
false,
{},
ck::host::DataType::Half,
ck::host::DataType::Half,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
const auto solutions = problem.GetSolutions("gfx90a");
const auto& solution = solutions.at(0);
const auto template_str = solution.template_str;
pass &= template_str.find("GemmSpecialization::Default") != std::string::npos;
}
return pass;
}
bool test_GetInstances()
{
bool pass = true;
{
//Col Col Fp16
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
true,
true,
false,
{},
ck::host::DataType::Half,
ck::host::DataType::Half,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 51;
}
{
//Col Row Fp16
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
true,
false,
false,
{},
ck::host::DataType::Half,
ck::host::DataType::Half,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 51;
}
{
//Row Col Fp16
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
false,
true,
false,
{},
ck::host::DataType::Half,
ck::host::DataType::Half,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 42;
}
{
//Row Row Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
false,
false,
false,
{},
ck::host::DataType::Int8,
ck::host::DataType::Int8,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 48;
}
{
//Col Col Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
true,
true,
false,
{},
ck::host::DataType::Int8,
ck::host::DataType::Int8,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 48;
}
{
//Col Row Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
true,
false,
false,
{},
ck::host::DataType::Int8,
ck::host::DataType::Int8,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 48;
}
{
//Row Col Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
false,
true,
false,
{},
ck::host::DataType::Int8,
ck::host::DataType::Int8,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 39;
}
{
//Row Row Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
false,
false,
false,
{},
ck::host::DataType::Int8,
ck::host::DataType::Int8,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 48;
}
return pass;
}
bool test_MakeLayoutsTuple()
{
bool pass = true;
{
// Empty Tuple
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
false,
false,
false,
{},
ck::host::DataType::Half,
ck::host::DataType::Half,
ck::host::DataType::Half,
{ck::host::DataType::Half},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
const auto solutions = problem.GetSolutions("gfx90a");
const auto& solution = solutions.at(0);
const auto template_str = solution.template_str;
pass &= template_str.find("ck::Tuple<>") != std::string::npos;
}
{
// RowColRow Tuple
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
false,
false,
false,
{false, true, false},
ck::host::DataType::Half,
ck::host::DataType::Half,
ck::host::DataType::Half,
{ck::host::DataType::Half},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
const auto solutions = problem.GetSolutions("gfx90a");
const auto& solution = solutions.at(0);
const auto template_str = solution.template_str;
pass &= template_str.find("ck::Tuple<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor>") != std::string::npos;
}
return pass;
}
bool test_MakeTypeTuple()
{
bool pass = true;
{
// Empty Tuple
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
false,
false,
false,
{true},
ck::host::DataType::Half,
ck::host::DataType::Half,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
const auto solutions = problem.GetSolutions("gfx90a");
const auto& solution = solutions.at(0);
const auto template_str = solution.template_str;
pass &= template_str.find("ck::Tuple<>") != std::string::npos;
}
{
// Half Int8 Tuple
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
false,
false,
false,
{},
ck::host::DataType::Half,
ck::host::DataType::Half,
ck::host::DataType::Half,
{ck::host::DataType::Half, ck::host::DataType::Int8},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
const auto solutions = problem.GetSolutions("gfx90a");
const auto& solution = solutions.at(0);
const auto template_str = solution.template_str;
pass &= template_str.find("ck::Tuple<ck::half_t, int8_t>") != std::string::npos;
}
return pass;
}
int main()
{
bool pass = true;
pass &= test_Problem();
pass &= test_GetGemmSpec();
pass &= test_GetInstances();
pass &= test_MakeLayoutsTuple();
pass &= test_MakeTypeTuple();
if(pass)
{
std::cout << "Test jit library: Pass" << std::endl;
return 0;
}
else
{
std::cout << "Test jit library: Fail" << std::endl;
return -1;
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment