Commit 8e0beb65 authored by Alan Turner
Browse files

Add unit tests

parent fcca3307
...@@ -249,6 +249,8 @@ include_directories(BEFORE ...@@ -249,6 +249,8 @@ include_directories(BEFORE
${HIP_INCLUDE_DIRS} ${HIP_INCLUDE_DIRS}
) )
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
if (NOT CK_BUILD_JIT_LIB) if (NOT CK_BUILD_JIT_LIB)
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV) if(BUILD_DEV)
...@@ -257,7 +259,7 @@ if (NOT CK_BUILD_JIT_LIB) ...@@ -257,7 +259,7 @@ if (NOT CK_BUILD_JIT_LIB)
endif() endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp") file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp")
file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*) file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*)
...@@ -286,7 +288,6 @@ if (NOT CK_BUILD_JIT_LIB) ...@@ -286,7 +288,6 @@ if (NOT CK_BUILD_JIT_LIB)
add_subdirectory(example) add_subdirectory(example)
add_subdirectory(test)
add_subdirectory(profiler) add_subdirectory(profiler)
else() else()
...@@ -297,7 +298,9 @@ else() ...@@ -297,7 +298,9 @@ else()
endif() endif()
add_subdirectory(library) add_subdirectory(library)
add_subdirectory(test)
#Create an interface target for the include only files and call it "composablekernels" #Create an interface target for the include only files and call it "composablekernels"
include(CMakePackageConfigHelpers) include(CMakePackageConfigHelpers)
......
...@@ -49,7 +49,7 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const ...@@ -49,7 +49,7 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
const bool quantize = ADataType == DataType::Int8 and BDataType == DataType::Int8; const bool quantize = ADataType == DataType::Int8 and BDataType == DataType::Int8;
if (get_xdlop_archs().find(arch) != get_xdlop_archs().end()) if (get_xdlop_archs().find(arch) != get_xdlop_archs().end())
{ {
instance::gemm_add_add_fastgelu_instances all_instances{}; ck::tensor_operation::device::instance::gemm_add_add_fastgelu_instances all_instances{};
if(TransA and TransB) if(TransA and TransB)
instances = all_instances.get_col_col_instances(quantize); instances = all_instances.get_col_col_instances(quantize);
else if(TransA and not TransB) else if(TransA and not TransB)
...@@ -139,7 +139,7 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const ...@@ -139,7 +139,7 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
std::string Problem::GetIncludeHeader() const std::string Problem::GetIncludeHeader() const
{ {
return instance::gemm_add_add_fastgelu_instances{}.get_include_header(); return ck::tensor_operation::device::instance::gemm_add_add_fastgelu_instances{}.get_include_header();
} }
std::vector<Solution> Problem::GetSolutions(const std::string& arch) const std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
......
...@@ -31,33 +31,37 @@ function(add_gtest_executable TEST_NAME) ...@@ -31,33 +31,37 @@ function(add_gtest_executable TEST_NAME)
rocm_install(TARGETS ${TEST_NAME} COMPONENT tests) rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
endfunction(add_gtest_executable TEST_NAME) endfunction(add_gtest_executable TEST_NAME)
add_subdirectory(magic_number_division) if(CK_BUILD_JIT_LIB)
add_subdirectory(space_filling_curve) add_subdirectory(jit_library)
add_subdirectory(conv_util) else()
add_subdirectory(reference_conv_fwd) add_subdirectory(magic_number_division)
add_subdirectory(gemm) add_subdirectory(space_filling_curve)
add_subdirectory(gemm_layernorm) add_subdirectory(conv_util)
add_subdirectory(gemm_split_k) add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm_reduce) add_subdirectory(gemm)
add_subdirectory(batched_gemm) add_subdirectory(gemm_layernorm)
add_subdirectory(batched_gemm_reduce) add_subdirectory(gemm_split_k)
add_subdirectory(batched_gemm_gemm) add_subdirectory(gemm_reduce)
add_subdirectory(batched_gemm_softmax_gemm) add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_softmax_gemm_permute) add_subdirectory(batched_gemm_reduce)
add_subdirectory(grouped_gemm) add_subdirectory(batched_gemm_gemm)
add_subdirectory(reduce) add_subdirectory(batched_gemm_softmax_gemm)
add_subdirectory(convnd_fwd) add_subdirectory(batched_gemm_softmax_gemm_permute)
add_subdirectory(convnd_bwd_data) add_subdirectory(grouped_gemm)
add_subdirectory(grouped_convnd_fwd) add_subdirectory(reduce)
add_subdirectory(grouped_convnd_bwd_weight) add_subdirectory(convnd_fwd)
add_subdirectory(block_to_ctile_map) add_subdirectory(convnd_bwd_data)
add_subdirectory(softmax) add_subdirectory(grouped_convnd_fwd)
add_subdirectory(normalization) add_subdirectory(grouped_convnd_bwd_weight)
add_subdirectory(data_type) add_subdirectory(block_to_ctile_map)
add_subdirectory(elementwise_normalization) add_subdirectory(softmax)
add_subdirectory(batchnorm) add_subdirectory(normalization)
add_subdirectory(contraction) add_subdirectory(data_type)
add_subdirectory(pool_fwd) add_subdirectory(elementwise_normalization)
add_subdirectory(batchnorm)
add_subdirectory(contraction)
add_subdirectory(pool_fwd)
endif()
if(GPU_TARGETS MATCHES "gfx1100") if(GPU_TARGETS MATCHES "gfx1100")
add_subdirectory(wmma_op) add_subdirectory(wmma_op)
endif() endif()
add_test_executable(test_jit_library jit_library.cpp)
add_dependencies(test_jit_library jit_library)
target_link_libraries(test_jit_library PRIVATE jit_library)
#include "ck/host/device_gemm_multiple_d.hpp"
#include <iostream>
// Checks that a 256x256x256 fp16 row-col problem on gfx90a yields the
// expected include header, solution count, and first-solution parameters.
bool test_Problem()
{
    const auto problem =
        ck::host::device_gemm_multiple_d::Problem{256,
                                                  256,
                                                  256,
                                                  false,
                                                  true,
                                                  false,
                                                  {},
                                                  ck::host::DataType::Half,
                                                  ck::host::DataType::Half,
                                                  ck::host::DataType::Half,
                                                  {},
                                                  "ck::tensor_operation::element_wise::Passthrough",
                                                  "ck::tensor_operation::element_wise::Passthrough",
                                                  "ck::tensor_operation::element_wise::Passthrough"};

    const auto solutions = problem.GetSolutions("gfx90a");
    const auto& first    = solutions.at(0);

    // All expectations combined into one result; `&` would not short-circuit,
    // but these checks are pure so short-circuiting is harmless here.
    return problem.GetIncludeHeader() ==
               "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" and
           solutions.size() == 42 and
           first.template_str ==
               "ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle< ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor, ck::half_t, ck::half_t, float, float, ck::Tuple<>, ck::half_t, ck::tensor_operation::element_wise::Passthrough, ck::tensor_operation::element_wise::Passthrough, ck::tensor_operation::element_wise::Passthrough, ck::tensor_operation::device::GemmSpecialization::Default, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, 1, 1, ck::Sequence<1,32,1,8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v1>" and
           first.grid_size == 2 and first.block_size == 256;
}
// Verifies that GEMM-specialization selection follows the problem sizes:
// non-tile-aligned dimensions force MNK padding, aligned ones use Default.
bool test_GetGemmSpec()
{
    // Extracts the template string of the first gfx90a solution.
    const auto first_template = [](const auto& problem) {
        return problem.GetSolutions("gfx90a").at(0).template_str;
    };

    // 255x255x255 is not tile-aligned, so M, N and K all get padded.
    const auto padded =
        ck::host::device_gemm_multiple_d::Problem{255,
                                                  255,
                                                  255,
                                                  false,
                                                  true,
                                                  false,
                                                  {},
                                                  ck::host::DataType::Half,
                                                  ck::host::DataType::Half,
                                                  ck::host::DataType::Half,
                                                  {},
                                                  "ck::tensor_operation::element_wise::Passthrough",
                                                  "ck::tensor_operation::element_wise::Passthrough",
                                                  "ck::tensor_operation::element_wise::Passthrough"};
    bool pass =
        first_template(padded).find("GemmSpecialization::MNKPadding") != std::string::npos;

    // 256x256x256 is tile-aligned, so no padding is needed.
    const auto aligned =
        ck::host::device_gemm_multiple_d::Problem{256,
                                                  256,
                                                  256,
                                                  false,
                                                  true,
                                                  false,
                                                  {},
                                                  ck::host::DataType::Half,
                                                  ck::host::DataType::Half,
                                                  ck::host::DataType::Half,
                                                  {},
                                                  "ck::tensor_operation::element_wise::Passthrough",
                                                  "ck::tensor_operation::element_wise::Passthrough",
                                                  "ck::tensor_operation::element_wise::Passthrough"};
    pass = pass and
           first_template(aligned).find("GemmSpecialization::Default") != std::string::npos;

    return pass;
}
bool test_GetInstances()
{
bool pass = true;
{
//Col Col Fp16
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
true,
true,
false,
{},
ck::host::DataType::Half,
ck::host::DataType::Half,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 51;
}
{
//Col Row Fp16
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
true,
false,
false,
{},
ck::host::DataType::Half,
ck::host::DataType::Half,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 51;
}
{
//Row Col Fp16
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
false,
true,
false,
{},
ck::host::DataType::Half,
ck::host::DataType::Half,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 42;
}
{
//Row Row Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
false,
false,
false,
{},
ck::host::DataType::Int8,
ck::host::DataType::Int8,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 48;
}
{
//Col Col Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
true,
true,
false,
{},
ck::host::DataType::Int8,
ck::host::DataType::Int8,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 48;
}
{
//Col Row Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
true,
false,
false,
{},
ck::host::DataType::Int8,
ck::host::DataType::Int8,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 48;
}
{
//Row Col Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
false,
true,
false,
{},
ck::host::DataType::Int8,
ck::host::DataType::Int8,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 39;
}
{
//Row Row Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{256,
256,
256,
false,
false,
false,
{},
ck::host::DataType::Int8,
ck::host::DataType::Int8,
ck::host::DataType::Half,
{},
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough",
"ck::tensor_operation::element_wise::Passthrough"};
pass &= problem.GetSolutions("gfx90a").size() == 48;
}
return pass;
}
// Verifies that the Ds layout tuple in the generated template tracks the
// Ds-transpose argument: empty when none is given, ordered layouts otherwise.
bool test_MakeLayoutsTuple()
{
    // Extracts the template string of the first gfx90a solution.
    const auto first_template = [](const auto& problem) {
        return problem.GetSolutions("gfx90a").at(0).template_str;
    };

    // No Ds layouts given: the template must contain an empty tuple.
    const auto no_layouts =
        ck::host::device_gemm_multiple_d::Problem{256,
                                                  256,
                                                  256,
                                                  false,
                                                  false,
                                                  false,
                                                  {},
                                                  ck::host::DataType::Half,
                                                  ck::host::DataType::Half,
                                                  ck::host::DataType::Half,
                                                  {ck::host::DataType::Half},
                                                  "ck::tensor_operation::element_wise::Passthrough",
                                                  "ck::tensor_operation::element_wise::Passthrough",
                                                  "ck::tensor_operation::element_wise::Passthrough"};
    bool pass = first_template(no_layouts).find("ck::Tuple<>") != std::string::npos;

    // Row/Col/Row transposes must appear as ordered layout types in the tuple.
    const auto row_col_row =
        ck::host::device_gemm_multiple_d::Problem{256,
                                                  256,
                                                  256,
                                                  false,
                                                  false,
                                                  false,
                                                  {false, true, false},
                                                  ck::host::DataType::Half,
                                                  ck::host::DataType::Half,
                                                  ck::host::DataType::Half,
                                                  {ck::host::DataType::Half},
                                                  "ck::tensor_operation::element_wise::Passthrough",
                                                  "ck::tensor_operation::element_wise::Passthrough",
                                                  "ck::tensor_operation::element_wise::Passthrough"};
    pass = pass and
           first_template(row_col_row)
                   .find("ck::Tuple<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor>") != std::string::npos;

    return pass;
}
// Verifies that the Ds data-type tuple in the generated template tracks the
// Ds-type argument: empty when none is given, ordered element types otherwise.
bool test_MakeTypeTuple()
{
    // Extracts the template string of the first gfx90a solution.
    const auto first_template = [](const auto& problem) {
        return problem.GetSolutions("gfx90a").at(0).template_str;
    };

    // No Ds types given: the template must contain an empty tuple.
    const auto no_types =
        ck::host::device_gemm_multiple_d::Problem{256,
                                                  256,
                                                  256,
                                                  false,
                                                  false,
                                                  false,
                                                  {true},
                                                  ck::host::DataType::Half,
                                                  ck::host::DataType::Half,
                                                  ck::host::DataType::Half,
                                                  {},
                                                  "ck::tensor_operation::element_wise::Passthrough",
                                                  "ck::tensor_operation::element_wise::Passthrough",
                                                  "ck::tensor_operation::element_wise::Passthrough"};
    bool pass = first_template(no_types).find("ck::Tuple<>") != std::string::npos;

    // Half followed by Int8 must appear as ordered element types in the tuple.
    const auto half_int8 =
        ck::host::device_gemm_multiple_d::Problem{256,
                                                  256,
                                                  256,
                                                  false,
                                                  false,
                                                  false,
                                                  {},
                                                  ck::host::DataType::Half,
                                                  ck::host::DataType::Half,
                                                  ck::host::DataType::Half,
                                                  {ck::host::DataType::Half, ck::host::DataType::Int8},
                                                  "ck::tensor_operation::element_wise::Passthrough",
                                                  "ck::tensor_operation::element_wise::Passthrough",
                                                  "ck::tensor_operation::element_wise::Passthrough"};
    pass = pass and
           first_template(half_int8).find("ck::Tuple<ck::half_t, int8_t>") != std::string::npos;

    return pass;
}
int main()
{
bool pass = true;
pass &= test_Problem();
pass &= test_GetGemmSpec();
pass &= test_GetInstances();
pass &= test_MakeLayoutsTuple();
pass &= test_MakeTypeTuple();
if(pass)
{
std::cout << "Test jit library: Pass" << std::endl;
return 0;
}
else
{
std::cout << "Test jit library: Fail" << std::endl;
return -1;
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment