"host/vscode:/vscode.git/clone" did not exist on "80120f0a0c524d1efc0249926a73d5020f0efd67"
Unverified commit a8629a98, authored by zjing14 and committed by GitHub

Merge branch 'develop' into gemm_v2r3_kpad_fix

parents 8dc713ea 94bfa502
@@ -28,6 +28,7 @@ set(PROFILER_SOURCES
     profile_contraction_bilinear.cpp
     profile_contraction_scale.cpp
     profile_grouped_conv_bwd_data.cpp
+    profile_image_to_column.cpp
 )
 if(DL_KERNELS)
     list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp)
@@ -82,6 +83,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_insta
 target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
 target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
 target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
+target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
 if(DL_KERNELS)
     target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance)
 endif()
...
@@ -23,6 +23,7 @@ enum struct GemmDataType
     F16_F16_F16,    // 1
     BF16_BF16_BF16, // 2
     INT8_INT8_INT8, // 3
+    F8_F8_F8,       // 4
 };
 #define OP_NAME "gemm"
@@ -31,7 +32,7 @@ enum struct GemmDataType
 static void print_helper_msg()
 {
     std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
-              << "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"
+              << "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: fp8)\n"
               << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
               << "                     1: A[m, k] * B[n, k] = C[m, n];\n"
              << "                     2: A[k, m] * B[k, n] = C[m, n];\n"
@@ -76,6 +77,9 @@ int profile_gemm(int argc, char* argv[])
     using INT8  = int8_t;
     using INT32 = int32_t;
 #endif
+#ifdef CK_ENABLE_FP8
+    using F8 = ck::f8_t;
+#endif
     using Row = ck::tensor_layout::gemm::RowMajor;
     using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -194,6 +198,24 @@ int profile_gemm(int argc, char* argv[])
     {
         return profile(Col{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
     }
+#endif
+#ifdef CK_ENABLE_FP8
+    else if(data_type == GemmDataType::F8_F8_F8 && layout == GemmMatrixLayout::MK_KN_MN)
+    {
+        return profile(Row{}, Row{}, Row{}, F8{}, F8{}, F32{}, F8{});
+    }
+    else if(data_type == GemmDataType::F8_F8_F8 && layout == GemmMatrixLayout::MK_NK_MN)
+    {
+        return profile(Row{}, Col{}, Row{}, F8{}, F8{}, F32{}, F8{});
+    }
+    else if(data_type == GemmDataType::F8_F8_F8 && layout == GemmMatrixLayout::KM_KN_MN)
+    {
+        return profile(Col{}, Row{}, Row{}, F8{}, F8{}, F32{}, F8{});
+    }
+    else if(data_type == GemmDataType::F8_F8_F8 && layout == GemmMatrixLayout::KM_NK_MN)
+    {
+        return profile(Col{}, Col{}, Row{}, F8{}, F8{}, F32{}, F8{});
+    }
 #endif
     else
     {
...
@@ -71,6 +71,9 @@ int profile_gemm_bilinear(int argc, char* argv[])
     using F16 = ck::half_t;
     using F32 = float;
+    using I8  = std::int8_t;
+    using I32 = std::int32_t;
     using Row = ck::tensor_layout::gemm::RowMajor;
     using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -141,6 +144,22 @@ int profile_gemm_bilinear(int argc, char* argv[])
     {
         return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Col{}, Col{}, Row{}, Row{});
     }
+    else if(data_type == MatrixDataType::INT8_INT8_INT8_INT8 && layout == MatrixLayout::MK_KN_MN_MN)
+    {
+        return profile(I8{}, I8{}, I32{}, I8{}, I8{}, Row{}, Row{}, Row{}, Row{});
+    }
+    else if(data_type == MatrixDataType::INT8_INT8_INT8_INT8 && layout == MatrixLayout::MK_NK_MN_MN)
+    {
+        return profile(I8{}, I8{}, I32{}, I8{}, I8{}, Row{}, Col{}, Row{}, Row{});
+    }
+    else if(data_type == MatrixDataType::INT8_INT8_INT8_INT8 && layout == MatrixLayout::KM_KN_MN_MN)
+    {
+        return profile(I8{}, I8{}, I32{}, I8{}, I8{}, Col{}, Row{}, Row{}, Row{});
+    }
+    else if(data_type == MatrixDataType::INT8_INT8_INT8_INT8 && layout == MatrixLayout::KM_NK_MN_MN)
+    {
+        return profile(I8{}, I8{}, I32{}, I8{}, I8{}, Col{}, Col{}, Row{}, Row{});
+    }
     else
     {
         std::cout << "this data_type & layout is not implemented" << std::endl;
...
@@ -59,9 +59,11 @@ int profile_gemm_multiply_add(int argc, char* argv[])
     const int StrideD1 = std::stoi(argv[14]);
     const int StrideE  = std::stoi(argv[15]);
-    using F8  = ck::f8_t;
     using F16 = ck::half_t;
     using F32 = float;
+#if defined CK_ENABLE_FP8
+    using F8 = ck::f8_t;
+#endif
     using Row = ck::tensor_layout::gemm::RowMajor;
     using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -132,6 +134,7 @@ int profile_gemm_multiply_add(int argc, char* argv[])
     {
         return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{});
     }
+#if defined CK_ENABLE_FP8
     else if(data_type == MatrixDataType::F16_F8_F32_F32_F16 &&
             layout == MatrixLayout::MK_KN_MN_MN_MN)
     {
@@ -142,6 +145,7 @@ int profile_gemm_multiply_add(int argc, char* argv[])
     {
         return profile(F16{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{});
     }
+#endif
     else
     {
         std::cout << "this data_type & layout is not implemented" << std::endl;
...
@@ -67,7 +67,9 @@ int profile_gemm_splitk(int argc, char* argv[])
     using F32 = float;
     using F16 = ck::half_t;
-    using F8  = ck::f8_t;
+#if defined CK_ENABLE_FP8
+    using F8 = ck::f8_t;
+#endif
     using Row = ck::tensor_layout::gemm::RowMajor;
     using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -146,6 +148,7 @@ int profile_gemm_splitk(int argc, char* argv[])
     {
         return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
     }
+#if defined CK_ENABLE_FP8
     else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
     {
         return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
@@ -178,6 +181,7 @@ int profile_gemm_splitk(int argc, char* argv[])
     {
         return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{});
     }
+#endif
     else
     {
         std::cout << "this data_type & layout is not implemented" << std::endl;
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_image_to_column_impl.hpp"
#include "profiler_operation_registry.hpp"
namespace {
enum struct ConvLayout
{
NHWC, // 0
};
enum struct DataType
{
F32_F32, // 0
F16_F16, // 1
BF16_BF16, // 2
INT8_INT8, // 3
};
#define OP_NAME "image_to_column"
#define OP_DESC "Image To Column"
static void print_helper_msg()
{
std::cout
// clang-format off
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n"
<< " 1: Input fp16, Weight fp16, Output fp16\n"
<< " 2: Input bf16, Weight bf16, Output bf16\n"
<< " 3: Input int8, Weight int8, Output int8)\n"
<< "arg3: tensor layout (0: Input[N, Hi, Wi, C], Output[N * Ho * Wo, Y * X * C])\n"
<< "arg4: verification (0: no, 1: yes)\n"
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0: no, 1: yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
// clang-format on
}
} // namespace
int profile_image_to_column(int argc, char* argv[])
{
// 8 for control, 1 for num_dim_spatial
if(argc < 9)
{
print_helper_msg();
return 1;
}
const auto data_type = static_cast<DataType>(std::stoi(argv[2]));
const auto layout = static_cast<ConvLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const int num_dim_spatial = std::stoi(argv[8]);
// 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial
if(argc != 8 + 1 + 4 + 6 * num_dim_spatial)
{
print_helper_msg();
return 1;
}
const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv);
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using INT8 = int8_t;
using namespace ck::tensor_layout::convolution;
constexpr auto I1 = ck::Number<1>{};
constexpr auto I2 = ck::Number<2>{};
constexpr auto I3 = ck::Number<3>{};
auto profile = [&](auto num_dim_spatial_tmp, auto in_layout, auto in_type, auto out_type) {
constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value;
using InLayout = decltype(in_layout);
using InDataType = decltype(in_type);
using OutDataType = decltype(out_type);
bool pass = ck::profiler::
profile_image_to_column_impl<NDimSpatial, InLayout, InDataType, OutDataType>(
do_verification, init_method, do_log, time_kernel, params);
return pass ? 0 : 1;
};
// NHWC
if(layout == ConvLayout::NHWC)
{
if(num_dim_spatial == 1)
{
if(data_type == DataType::F32_F32)
{
return profile(I1, GNWC{}, F32{}, F32{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I1, GNWC{}, F16{}, F16{});
}
else if(data_type == DataType::BF16_BF16)
{
return profile(I1, GNWC{}, BF16{}, BF16{});
}
else if(data_type == DataType::INT8_INT8)
{
return profile(I1, GNWC{}, INT8{}, INT8{});
}
}
else if(num_dim_spatial == 2)
{
if(data_type == DataType::F32_F32)
{
return profile(I2, GNHWC{}, F32{}, F32{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I2, GNHWC{}, F16{}, F16{});
}
else if(data_type == DataType::BF16_BF16)
{
return profile(I2, GNHWC{}, BF16{}, BF16{});
}
else if(data_type == DataType::INT8_INT8)
{
return profile(I2, GNHWC{}, INT8{}, INT8{});
}
}
else if(num_dim_spatial == 3)
{
if(data_type == DataType::F32_F32)
{
return profile(I3, GNDHWC{}, F32{}, F32{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I3, GNDHWC{}, F16{}, F16{});
}
else if(data_type == DataType::BF16_BF16)
{
return profile(I3, GNDHWC{}, BF16{}, BF16{});
}
else if(data_type == DataType::INT8_INT8)
{
return profile(I3, GNDHWC{}, INT8{}, INT8{});
}
}
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_image_to_column);
@@ -9,26 +9,121 @@ add_custom_target(tests)
 function(add_test_executable TEST_NAME)
     message("adding test ${TEST_NAME}")
-    add_executable(${TEST_NAME} ${ARGN})
-    add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
-    add_dependencies(tests ${TEST_NAME})
-    add_dependencies(check ${TEST_NAME})
-    rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
+    set(result 1)
+    if(DEFINED DTYPES)
+        foreach(source IN LISTS ARGN)
+            set(test 0)
+            foreach(type IN LISTS DTYPES)
+                if(type MATCHES "fp16")
+                    set(type1 "_f16")
+                elseif(type MATCHES "fp32")
+                    set(type1 "_f32")
+                elseif(type MATCHES "fp8")
+                    set(type1 "_f8")
+                elseif(type MATCHES "bf16")
+                    set(type1 "_b16")
+                elseif(type MATCHES "fp64")
+                    set(type1 "_f64")
+                elseif(type MATCHES "int8")
+                    set(type1 "_i8")
+                endif()
+                if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
+                    # if the filename matches any selected type, exit the type loop and do not exclude the file from the list
+                    set(test 0)
+                    break()
+                elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
+                        source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
+                       NOT(source MATCHES type OR source MATCHES type1))
+                    # if the filename contains a type which doesn't match any selected type, mark it for removal
+                    set(test 1)
+                endif()
+            endforeach()
+            if(test EQUAL 1)
+                message("removing test ${source} ")
+                list(REMOVE_ITEM ARGN "${source}")
+            endif()
+        endforeach()
+    endif()
+    foreach(source IN LISTS ARGN)
+        if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
+            message("removing dl test ${source} ")
+            list(REMOVE_ITEM ARGN "${source}")
+        endif()
+    endforeach()
+    # only continue if there are some source files left on the list
+    if(ARGN)
+        add_executable(${TEST_NAME} ${ARGN})
+        add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
+        add_dependencies(tests ${TEST_NAME})
+        add_dependencies(check ${TEST_NAME})
+        rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
+        set(result 0)
+    endif()
+    #message("add_test returns ${result}")
+    return(PROPAGATE result)
 endfunction(add_test_executable TEST_NAME)
 include(GoogleTest)
 function(add_gtest_executable TEST_NAME)
     message("adding gtest ${TEST_NAME}")
-    add_executable(${TEST_NAME} ${ARGN})
-    add_dependencies(tests ${TEST_NAME})
-    add_dependencies(check ${TEST_NAME})
+    set(result 1)
+    if(DEFINED DTYPES)
+        foreach(source IN LISTS ARGN)
+            set(test 0)
+            foreach(type IN LISTS DTYPES)
+                if(type MATCHES "fp16")
+                    set(type1 "_f16")
+                elseif(type MATCHES "fp32")
+                    set(type1 "_f32")
+                elseif(type MATCHES "fp8")
+                    set(type1 "_f8")
+                elseif(type MATCHES "bf16")
+                    set(type1 "_b16")
+                elseif(type MATCHES "fp64")
+                    set(type1 "_f64")
+                elseif(type MATCHES "int8")
+                    set(type1 "_i8")
+                endif()
+                if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
+                    # if the filename matches any selected type, exit the type loop and do not exclude the file from the list
+                    set(test 0)
+                    break()
+                elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
+                        source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
+                       NOT(source MATCHES type OR source MATCHES type1))
+                    # if the filename contains a type which doesn't match any selected type, mark it for removal
+                    set(test 1)
+                endif()
+            endforeach()
+            if(test EQUAL 1)
+                message("removing gtest ${source} ")
+                list(REMOVE_ITEM ARGN "${source}")
+            endif()
+        endforeach()
+    endif()
+    foreach(source IN LISTS ARGN)
+        if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
+            message("removing dl test ${source} ")
+            list(REMOVE_ITEM ARGN "${source}")
+        endif()
+    endforeach()
+    # only continue if there are some source files left on the list
+    if(ARGN)
+        add_executable(${TEST_NAME} ${ARGN})
+        add_dependencies(tests ${TEST_NAME})
+        add_dependencies(check ${TEST_NAME})
         # suppress gtest warnings
         target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef)
         target_link_libraries(${TEST_NAME} PRIVATE gtest_main)
         add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
         rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
+        set(result 0)
+    endif()
+    #message("add_gtest returns ${result}")
+    return(PROPAGATE result)
 endfunction(add_gtest_executable TEST_NAME)
 add_subdirectory(magic_number_division)
@@ -60,6 +155,7 @@ add_subdirectory(contraction)
 add_subdirectory(pool)
 add_subdirectory(batched_gemm_multi_d)
 add_subdirectory(grouped_convnd_bwd_data)
+add_subdirectory(image_to_column)
 if(GPU_TARGETS MATCHES "gfx11")
     add_subdirectory(wmma_op)
 endif()
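With this change the helpers report back to their callers through a `result` variable instead of requiring each caller to repeat the DTYPES check. Below is a minimal sketch of that pattern; the function and target names are hypothetical, the add_executable call is commented out so the sketch also runs in script mode, and note that return(PROPAGATE ...) requires CMake 3.25 or newer (policy CMP0140).

    cmake_minimum_required(VERSION 3.25)   # return(PROPAGATE) needs 3.25+ (policy CMP0140)

    function(maybe_add_target NAME)
        set(result 1)                      # 1: no target was created
        if(ARGN)                           # only create a target if any sources survived filtering
            # add_executable(${NAME} ${ARGN})
            set(result 0)                  # 0: target exists, caller may link libraries to it
        endif()
        return(PROPAGATE result)           # export `result` into the caller's scope
    endfunction()

    maybe_add_target(test_example example.cpp)
    if(result EQUAL 0)
        message(STATUS "test_example was created; link its dependencies here")
    endif()

This is the same "add, then link only if result EQUAL 0" shape used by the per-test CMakeLists files below.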
@@ -2,25 +2,21 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)
-        if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
-            add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp)
-            target_link_libraries(test_batched_gemm_fp16 PRIVATE utility)
-            target_link_libraries(test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance)
+        add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp)
+        if(result EQUAL 0)
+            target_link_libraries(test_batched_gemm_fp16 PRIVATE utility device_batched_gemm_instance)
         endif()
-        if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
-            add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp)
-            target_link_libraries(test_batched_gemm_fp32 PRIVATE utility)
-            target_link_libraries(test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance)
+        add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp)
+        if(result EQUAL 0)
+            target_link_libraries(test_batched_gemm_fp32 PRIVATE utility device_batched_gemm_instance)
         endif()
-        if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
-            add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp)
-            target_link_libraries(test_batched_gemm_bf16 PRIVATE utility)
-            target_link_libraries(test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance)
+        add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp)
+        if(result EQUAL 0)
+            target_link_libraries(test_batched_gemm_bf16 PRIVATE utility device_batched_gemm_instance)
         endif()
-        if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
-            add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp)
-            target_link_libraries(test_batched_gemm_int8 PRIVATE utility)
-            target_link_libraries(test_batched_gemm_int8 PRIVATE device_batched_gemm_instance)
+        add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp)
+        if(result EQUAL 0)
+            target_link_libraries(test_batched_gemm_int8 PRIVATE utility device_batched_gemm_instance)
         endif()
         set(target 1)
     endif()
...
@@ -2,12 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)
-        if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
-            add_custom_target(test_batched_gemm_gemm)
-            add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp)
+        add_custom_target(test_batched_gemm_gemm)
+        add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp)
+        if(result EQUAL 0)
             target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance)
             add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16)
             set(target 1)
         endif()
     endif()
 endforeach()
\ No newline at end of file
-if(DL_KERNELS)
-    add_gtest_executable(test_batched_gemm_multi_d test_batched_gemm_multi_d.cpp)
+add_gtest_executable(test_batched_gemm_multi_d test_batched_gemm_multi_d_dl.cpp)
+if(result EQUAL 0)
     target_link_libraries(test_batched_gemm_multi_d PRIVATE utility device_batched_gemm_multi_d_instance)
 endif()
@@ -2,10 +2,9 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)
-        if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
-            add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp)
-            target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility)
-            target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance)
+        add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp)
+        if(result EQUAL 0)
+            target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance)
             set(target 1)
         endif()
     endif()
...
@@ -2,12 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)
-        if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
-            add_custom_target(test_batched_gemm_softmax_gemm)
-            add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp)
+        add_custom_target(test_batched_gemm_softmax_gemm)
+        add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp)
+        if(result EQUAL 0)
             target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance)
             add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16)
             set(target 1)
         endif()
     endif()
 endforeach()
\ No newline at end of file
@@ -2,25 +2,28 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)
-        if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
-            add_custom_target(test_batched_gemm_softmax_gemm_permute)
-        endif()
-        if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
-            add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp)
-            add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp)
-            target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
-            target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
-            add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16)
-            add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16)
-        endif()
-        if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
-            add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp)
-            add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp)
-            target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
-            target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
-            add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16)
-            add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16)
-        endif()
+        add_custom_target(test_batched_gemm_softmax_gemm_permute)
+        add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp)
+        if(result EQUAL 0)
+            target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
+            add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16)
+        endif()
+        add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp)
+        if(result EQUAL 0)
+            target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
+            add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16)
+        endif()
+        add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp)
+        if(result EQUAL 0)
+            target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
+            add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16)
+        endif()
+        add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp)
+        if(result EQUAL 0)
+            target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
+            add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16)
+        endif()
         set(target 1)
     endif()
 endforeach()
\ No newline at end of file
@@ -70,10 +70,23 @@ class TestBatchNormBwdRank4 : public ::testing::Test
     }
 };
-using KernelTypes = ::testing::Types<std::tuple<F16, F32, F32, F32, F16, F32, F32>,
-                                     std::tuple<F32, F32, F32, F32, F32, F32, F32>,
-                                     std::tuple<BF16, F32, F32, F32, BF16, F32, F32>,
-                                     std::tuple<F64, F64, F64, F64, F64, F64, F64>>;
+using KernelTypes = ::testing::Types<
+#ifdef CK_ENABLE_FP16
+    std::tuple<F16, F32, F32, F32, F16, F32, F32>
+#endif
+#ifdef CK_ENABLE_FP32
+    ,
+    std::tuple<F32, F32, F32, F32, F32, F32, F32>
+#endif
+#ifdef CK_ENABLE_BF16
+    ,
+    std::tuple<BF16, F32, F32, F32, BF16, F32, F32>
+#endif
+#ifdef CK_ENABLE_FP64
+    ,
+    std::tuple<F64, F64, F64, F64, F64, F64, F64>
+#endif
+    >;
 TYPED_TEST_SUITE(TestBatchNormBwdRank4, KernelTypes);
...
@@ -87,10 +87,23 @@ class TestBatchNormFwdRank4 : public ::testing::Test
     }
 };
-using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, F32>,
-                                     std::tuple<F32, F32, F32, F32, F32, F32>,
-                                     std::tuple<BF16, BF16, F32, BF16, BF16, F32>,
-                                     std::tuple<F64, F64, F64, F64, F64, F64>>;
+using KernelTypes = ::testing::Types<
+#ifdef CK_ENABLE_FP16
+    std::tuple<F16, F16, F32, F16, F16, F32>
+#endif
+#ifdef CK_ENABLE_FP32
+    ,
+    std::tuple<F32, F32, F32, F32, F32, F32>
+#endif
+#ifdef CK_ENABLE_BF16
+    ,
+    std::tuple<BF16, BF16, F32, BF16, BF16, F32>
+#endif
+#ifdef CK_ENABLE_FP64
+    ,
+    std::tuple<F64, F64, F64, F64, F64, F64>
+#endif
+    >;
 TYPED_TEST_SUITE(TestBatchNormFwdRank4, KernelTypes);
...
@@ -67,10 +67,23 @@ class TestBatchNormInferRank4 : public ::testing::Test
     }
 };
-using KernelTypes = ::testing::Types<std::tuple<F16, F16, F32, F16, F16, F32>,
-                                     std::tuple<F32, F32, F32, F32, F32, F32>,
-                                     std::tuple<BF16, BF16, F32, BF16, BF16, F32>,
-                                     std::tuple<F64, F64, F64, F64, F64, F64>>;
+using KernelTypes = ::testing::Types<
+#ifdef CK_ENABLE_FP16
+    std::tuple<F16, F16, F32, F16, F16, F32>
+#endif
+#ifdef CK_ENABLE_FP32
+    ,
+    std::tuple<F32, F32, F32, F32, F32, F32>
+#endif
+#ifdef CK_ENABLE_BF16
+    ,
+    std::tuple<BF16, BF16, F32, BF16, BF16, F32>
+#endif
+#ifdef CK_ENABLE_FP64
+    ,
+    std::tuple<F64, F64, F64, F64, F64, F64>
+#endif
+    >;
 TYPED_TEST_SUITE(TestBatchNormInferRank4, KernelTypes);
...
 if (USE_BITINT_EXTENSION_INT4)
     add_gtest_executable(test_int4 int4.cpp)
-    target_link_libraries(test_int4 PRIVATE utility)
+    if(result EQUAL 0)
+        target_link_libraries(test_int4 PRIVATE utility)
+    endif()
 endif()
 add_gtest_executable(test_fp8 fp8.cpp)
-target_link_libraries(test_fp8 PRIVATE utility)
+if(result EQUAL 0)
+    target_link_libraries(test_fp8 PRIVATE utility)
+endif()
+add_gtest_executable(test_bf8 bf8.cpp)
+if(result EQUAL 0)
+    target_link_libraries(test_bf8 PRIVATE utility)
+endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::bf8_t;
using ck::f8_convert_sr;
using ck::half_t;
using ck::type_convert;
TEST(BF8, NumericLimits)
{
// constants given for negative zero nan mode
EXPECT_EQ(ck::NumericLimits<bf8_t>::Min(), type_convert<bf8_t>(0x04));
EXPECT_EQ(ck::NumericLimits<bf8_t>::Max(), type_convert<bf8_t>(0x7F));
EXPECT_EQ(ck::NumericLimits<bf8_t>::Lowest(), type_convert<bf8_t>(0xFF));
EXPECT_EQ(ck::NumericLimits<bf8_t>::QuietNaN(), type_convert<bf8_t>(0x80));
}
TEST(BF8, ConvertFP32Nearest)
{
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(type_convert<bf8_t>(0.0f)), abs_tol);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(type_convert<bf8_t>(std::numeric_limits<float>::min())),
abs_tol);
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR(57344.0f, type_convert<float>(type_convert<bf8_t>(57344.0f)), abs_tol);
// convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR(57344.0f,
type_convert<float>(type_convert<bf8_t>(std::numeric_limits<float>::max())),
abs_tol);
// convert inf float to bf8_t and check if it is qNan
ASSERT_NEAR(type_convert<bf8_t>(0x80),
type_convert<bf8_t>(std::numeric_limits<float>::infinity()),
abs_tol);
// positive norm float value to bf8 and back, check if holds
float pos_float = 0.0000762939f;
ASSERT_NEAR(pos_float, type_convert<float>(type_convert<bf8_t>(pos_float)), abs_tol);
// negative norm float value to bf8 and back, check if holds
float neg_float = -0.0000610351f;
ASSERT_NEAR(neg_float, type_convert<float>(type_convert<bf8_t>(neg_float)), abs_tol);
// positive subnorm float value to bf8 and back, check if holds
pos_float = 0.0000305175f;
ASSERT_NEAR(pos_float, type_convert<float>(type_convert<bf8_t>(pos_float)), abs_tol);
// negative subnorm float value to bf8 and back, check if holds
neg_float = -0.0000152587f;
ASSERT_NEAR(neg_float, type_convert<float>(type_convert<bf8_t>(neg_float)), abs_tol);
}
TEST(BF8, ConvertFP32Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<bf8_t>(0.0f)), abs_tol);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_sr<bf8_t>(std::numeric_limits<float>::min())),
abs_tol);
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR(57344.0f, type_convert<float>(f8_convert_sr<bf8_t>(57344.0f)), abs_tol);
// convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR(57344.0f,
type_convert<float>(f8_convert_sr<bf8_t>(std::numeric_limits<float>::max())),
abs_tol);
// convert inf float to bf8_t and check if it is qNan
ASSERT_NEAR(type_convert<bf8_t>(0x80),
f8_convert_sr<bf8_t>(std::numeric_limits<float>::infinity()),
abs_tol);
// positive norm float value to bf8 and back, check if holds
float pos_float = 0.0000762939f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_t>(pos_float)), abs_tol);
// negative norm float value to bf8 and back, check if holds
float neg_float = -0.0000610351f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_t>(neg_float)), abs_tol);
// positive subnorm float value to bf8 and back, check if holds
pos_float = 0.0000305175f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_t>(pos_float)), abs_tol);
// negative subnorm float value to bf8 and back, check if holds
neg_float = -0.0000152587f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_t>(neg_float)), abs_tol);
}
TEST(BF8, ConvertFP16Nearest)
{
// fix the tolerance value
float abs_tol = 1e-3;
// convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(type_convert<bf8_t>(half_t{0.0})), abs_tol);
// convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(type_convert<bf8_t>(ck::NumericLimits<half_t>::Min())),
abs_tol);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
ASSERT_NEAR(
half_t{57344.0}, type_convert<half_t>(type_convert<bf8_t>(half_t{57344.0})), abs_tol);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR(half_t{57344.0},
type_convert<half_t>(type_convert<bf8_t>(ck::NumericLimits<half_t>::Max())),
abs_tol);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
ASSERT_NEAR(type_convert<bf8_t>(0x80),
type_convert<bf8_t>(ck::NumericLimits<half_t>::QuietNaN()),
abs_tol);
// positive norm fp16 value to bf8 and back, check if holds
half_t pos_half = half_t{0.0000762939};
ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<bf8_t>(pos_half)), abs_tol);
// negative norm fp16 value to bf8 and back, check if holds
half_t neg_half = half_t{-0.0000610351};
ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<bf8_t>(neg_half)), abs_tol);
// positive subnorm fp16 value to bf8 and back, check if holds
pos_half = half_t{0.0000305175};
ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<bf8_t>(pos_half)), abs_tol);
// negative subnorm fp16 value to bf8 and back, check if holds
neg_half = half_t{-0.0000152587};
ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<bf8_t>(neg_half)), abs_tol);
}
TEST(BF8, ConvertFP16Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-3;
// convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<bf8_t>(half_t{0.0})), abs_tol);
// convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::Min())),
abs_tol);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
ASSERT_NEAR(
half_t{57344.0}, type_convert<half_t>(f8_convert_sr<bf8_t>(half_t{57344.0})), abs_tol);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR(half_t{57344.0},
type_convert<half_t>(f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::Max())),
abs_tol);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
ASSERT_NEAR(type_convert<bf8_t>(0x80),
f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::QuietNaN()),
abs_tol);
// positive norm fp16 value to bf8 and back, check if holds
half_t pos_half = half_t{0.0000762939};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_t>(pos_half)), abs_tol);
// negative norm fp16 value to bf8 and back, check if holds
half_t neg_half = half_t{-0.0000610351};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_t>(neg_half)), abs_tol);
// positive subnorm fp16 value to bf8 and back, check if holds
pos_half = half_t{0.0000305175};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_t>(pos_half)), abs_tol);
// negative subnorm fp16 value to bf8 and back, check if holds
neg_half = half_t{-0.0000152587};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_t>(neg_half)), abs_tol);
}
@@ -12,10 +12,11 @@ using ck::type_convert;
 TEST(FP8, NumericLimits)
 {
-    EXPECT_EQ(ck::NumericLimits<f8_t>::Min(), 0x08);
-    EXPECT_EQ(ck::NumericLimits<f8_t>::Max(), 0x77);
-    EXPECT_EQ(ck::NumericLimits<f8_t>::Lowest(), 0xF7);
-    EXPECT_EQ(ck::NumericLimits<f8_t>::QuietNaN(), 0x80);
+    // constants given for negative zero nan mode
+    EXPECT_EQ(ck::NumericLimits<f8_t>::Min(), type_convert<f8_t>(0x08));
+    EXPECT_EQ(ck::NumericLimits<f8_t>::Max(), type_convert<f8_t>(0x7F));
+    EXPECT_EQ(ck::NumericLimits<f8_t>::Lowest(), type_convert<f8_t>(0xFF));
+    EXPECT_EQ(ck::NumericLimits<f8_t>::QuietNaN(), type_convert<f8_t>(0x80));
 }
 TEST(FP8, ConvertFP32Nearest)
@@ -35,12 +36,20 @@ TEST(FP8, ConvertFP32Nearest)
                 type_convert<float>(type_convert<f8_t>(std::numeric_limits<float>::max())),
                 abs_tol);
     // convert inf float to f8_t and check if it is qNan
-    ASSERT_NEAR(0x80, type_convert<f8_t>(std::numeric_limits<float>::infinity()), abs_tol);
-    // positive float value to fp8 and back, check if holds
-    float pos_float = 0.0078125f;
-    ASSERT_NEAR(pos_float, type_convert<float>(type_convert<f8_t>(pos_float)), abs_tol);
-    // negative float value to fp8 and back, check if holds
-    float neg_float = -0.0156250f;
+    ASSERT_NEAR(type_convert<f8_t>(0x80),
+                type_convert<f8_t>(std::numeric_limits<float>::infinity()),
+                abs_tol);
+    // positive norm float value to fp8 and back, check if holds
+    float pos_float = 0.017578125f;
+    ASSERT_NEAR(pos_float, type_convert<float>(type_convert<f8_t>(pos_float)), abs_tol);
+    // negative norm float value to fp8 and back, check if holds
+    float neg_float = -0.015625f;
+    ASSERT_NEAR(neg_float, type_convert<float>(type_convert<f8_t>(neg_float)), abs_tol);
+    // positive subnorm float value to fp8 and back, check if holds
+    pos_float = 0.00390625f;
+    ASSERT_NEAR(pos_float, type_convert<float>(type_convert<f8_t>(pos_float)), abs_tol);
+    // negative subnorm float value to fp8 and back, check if holds
+    neg_float = -0.001953125f;
     ASSERT_NEAR(neg_float, type_convert<float>(type_convert<f8_t>(neg_float)), abs_tol);
 }
@@ -61,12 +70,20 @@ TEST(FP8, ConvertFP32Stochastic)
                 type_convert<float>(f8_convert_sr<f8_t>(std::numeric_limits<float>::max())),
                 abs_tol);
     // convert inf float to f8_t and check if it is qNan
-    ASSERT_NEAR(0x80, f8_convert_sr<f8_t>(std::numeric_limits<float>::infinity()), abs_tol);
-    // positive float value to fp8 and back, check if holds
-    float pos_float = 0.0078125f;
-    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol);
-    // negative float value to fp8 and back, check if holds
-    float neg_float = -0.0156250f;
+    ASSERT_NEAR(type_convert<f8_t>(0x80),
+                f8_convert_sr<f8_t>(std::numeric_limits<float>::infinity()),
+                abs_tol);
+    // positive norm float value to fp8 and back, check if holds
+    float pos_float = 0.017578125f;
+    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol);
+    // negative norm float value to fp8 and back, check if holds
+    float neg_float = -0.015625f;
+    ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_t>(neg_float)), abs_tol);
+    // positive subnorm float value to fp8 and back, check if holds
+    pos_float = 0.00390625f;
+    ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol);
+    // negative subnorm float value to fp8 and back, check if holds
+    neg_float = -0.001953125f;
     ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_t>(neg_float)), abs_tol);
 }
@@ -87,12 +104,20 @@ TEST(FP8, ConvertFP16Nearest)
                 type_convert<half_t>(type_convert<f8_t>(ck::NumericLimits<half_t>::Max())),
                 abs_tol);
     // convert QuietNaN fp16 to f8_t and check if it is QuietNaN
-    ASSERT_NEAR(0x80, type_convert<f8_t>(ck::NumericLimits<half_t>::QuietNaN()), abs_tol);
-    // positive fp16 value to fp8 and back, check if holds
-    half_t pos_half = half_t{0.0078125};
-    ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<f8_t>(pos_half)), abs_tol);
-    // negative fp16 value to fp8 and back, check if holds
-    half_t neg_half = half_t{-0.0156250};
+    ASSERT_NEAR(type_convert<f8_t>(0x80),
+                type_convert<f8_t>(ck::NumericLimits<half_t>::QuietNaN()),
+                abs_tol);
+    // positive norm fp16 value to fp8 and back, check if holds
+    half_t pos_half = half_t{0.017578125};
+    ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<f8_t>(pos_half)), abs_tol);
+    // negative norm fp16 value to fp8 and back, check if holds
+    half_t neg_half = half_t{-0.015625};
+    ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<f8_t>(neg_half)), abs_tol);
+    // positive subnorm fp16 value to fp8 and back, check if holds
+    pos_half = half_t{0.00390625};
+    ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<f8_t>(pos_half)), abs_tol);
+    // negative subnorm fp16 value to fp8 and back, check if holds
+    neg_half = half_t{-0.001953125};
     ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<f8_t>(neg_half)), abs_tol);
 }
@@ -113,11 +138,19 @@ TEST(FP8, ConvertFP16Stochastic)
                 type_convert<half_t>(f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::Max())),
                 abs_tol);
     // convert QuietNaN fp16 to f8_t and check if it is QuietNaN
-    ASSERT_NEAR(0x80, f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::QuietNaN()), abs_tol);
-    // positive fp16 value to fp8 and back, check if holds
-    half_t pos_half = half_t{0.0078125};
-    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol);
-    // negative fp16 value to fp8 and back, check if holds
-    half_t neg_half = half_t{-0.0156250};
+    ASSERT_NEAR(type_convert<f8_t>(0x80),
+                f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::QuietNaN()),
+                abs_tol);
+    // positive norm fp16 value to fp8 and back, check if holds
+    half_t pos_half = half_t{0.017578125};
+    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol);
+    // negative norm fp16 value to fp8 and back, check if holds
+    half_t neg_half = half_t{-0.015625};
+    ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_t>(neg_half)), abs_tol);
+    // positive subnorm fp16 value to fp8 and back, check if holds
+    pos_half = half_t{0.00390625};
+    ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol);
+    // negative subnorm fp16 value to fp8 and back, check if holds
+    neg_half = half_t{-0.001953125};
     ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_t>(neg_half)), abs_tol);
 }