Unverified Commit 755ace59 authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Merge branch 'develop' into add_int8_wmma_example_instance

parents e24f37fb 63cd4592
...@@ -83,8 +83,8 @@ bool profile_gemm_multiply_add_impl(int do_verification, ...@@ -83,8 +83,8 @@ bool profile_gemm_multiply_add_impl(int do_verification,
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-1, 1}); d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-1, 1});
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 0.2});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.1, 0.1});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0}); d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0}); d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
} }
......
...@@ -214,6 +214,7 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -214,6 +214,7 @@ bool profile_gemm_splitk_impl(int do_verification,
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch " << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
<< kbatch_curr << std::endl; << kbatch_curr << std::endl;
#if defined CK_ENABLE_FP8
// set softer tolerances for fp8 // set softer tolerances for fp8
if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> || if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
is_same_v<CDataType, f8_t>) is_same_v<CDataType, f8_t>)
...@@ -226,8 +227,11 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -226,8 +227,11 @@ bool profile_gemm_splitk_impl(int do_verification,
} }
else else
{ {
#endif
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#if defined CK_ENABLE_FP8
} }
#endif
if(tflops > best_tflops) if(tflops > best_tflops)
{ {
......
...@@ -59,9 +59,11 @@ int profile_gemm_multiply_add(int argc, char* argv[]) ...@@ -59,9 +59,11 @@ int profile_gemm_multiply_add(int argc, char* argv[])
const int StrideD1 = std::stoi(argv[14]); const int StrideD1 = std::stoi(argv[14]);
const int StrideE = std::stoi(argv[15]); const int StrideE = std::stoi(argv[15]);
using F8 = ck::f8_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
#if defined CK_ENABLE_FP8
using F8 = ck::f8_t;
#endif
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -132,6 +134,7 @@ int profile_gemm_multiply_add(int argc, char* argv[]) ...@@ -132,6 +134,7 @@ int profile_gemm_multiply_add(int argc, char* argv[])
{ {
return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{}); return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{});
} }
#if defined CK_ENABLE_FP8
else if(data_type == MatrixDataType::F16_F8_F32_F32_F16 && else if(data_type == MatrixDataType::F16_F8_F32_F32_F16 &&
layout == MatrixLayout::MK_KN_MN_MN_MN) layout == MatrixLayout::MK_KN_MN_MN_MN)
{ {
...@@ -142,6 +145,7 @@ int profile_gemm_multiply_add(int argc, char* argv[]) ...@@ -142,6 +145,7 @@ int profile_gemm_multiply_add(int argc, char* argv[])
{ {
return profile(F16{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{}); return profile(F16{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{});
} }
#endif
else else
{ {
std::cout << "this data_type & layout is not implemented" << std::endl; std::cout << "this data_type & layout is not implemented" << std::endl;
......
...@@ -67,7 +67,9 @@ int profile_gemm_splitk(int argc, char* argv[]) ...@@ -67,7 +67,9 @@ int profile_gemm_splitk(int argc, char* argv[])
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using F8 = ck::f8_t; #if defined CK_ENABLE_FP8
using F8 = ck::f8_t;
#endif
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -146,6 +148,7 @@ int profile_gemm_splitk(int argc, char* argv[]) ...@@ -146,6 +148,7 @@ int profile_gemm_splitk(int argc, char* argv[])
{ {
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}); return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
} }
#if defined CK_ENABLE_FP8
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
...@@ -178,6 +181,7 @@ int profile_gemm_splitk(int argc, char* argv[]) ...@@ -178,6 +181,7 @@ int profile_gemm_splitk(int argc, char* argv[])
{ {
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{}); return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{});
} }
#endif
else else
{ {
std::cout << "this data_type & layout is not implemented" << std::endl; std::cout << "this data_type & layout is not implemented" << std::endl;
......
...@@ -3,5 +3,12 @@ if (USE_BITINT_EXTENSION_INT4) ...@@ -3,5 +3,12 @@ if (USE_BITINT_EXTENSION_INT4)
target_link_libraries(test_int4 PRIVATE utility) target_link_libraries(test_int4 PRIVATE utility)
endif() endif()
add_gtest_executable(test_fp8 fp8.cpp) if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES)
target_link_libraries(test_fp8 PRIVATE utility) add_gtest_executable(test_f8 f8.cpp)
target_link_libraries(test_f8 PRIVATE utility)
endif()
if(DTYPES MATCHES "bf8" OR NOT DEFINED DTYPES)
add_gtest_executable(test_bf8 bf8.cpp)
target_link_libraries(test_bf8 PRIVATE utility)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::bf8_t;
using ck::f8_convert_sr;
using ck::half_t;
using ck::type_convert;
TEST(BF8, NumericLimits)
{
// constants given for negative zero nan mode
EXPECT_EQ(ck::NumericLimits<bf8_t>::Min(), type_convert<bf8_t>(0x04));
EXPECT_EQ(ck::NumericLimits<bf8_t>::Max(), type_convert<bf8_t>(0x7F));
EXPECT_EQ(ck::NumericLimits<bf8_t>::Lowest(), type_convert<bf8_t>(0xFF));
EXPECT_EQ(ck::NumericLimits<bf8_t>::QuietNaN(), type_convert<bf8_t>(0x80));
}
TEST(BF8, ConvertFP32Nearest)
{
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(type_convert<bf8_t>(0.0f)), abs_tol);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(type_convert<bf8_t>(std::numeric_limits<float>::min())),
abs_tol);
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR(57344.0f, type_convert<float>(type_convert<bf8_t>(57344.0f)), abs_tol);
// convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR(57344.0f,
type_convert<float>(type_convert<bf8_t>(std::numeric_limits<float>::max())),
abs_tol);
// convert inf float to bf8_t and check if it is qNan
ASSERT_NEAR(type_convert<bf8_t>(0x80),
type_convert<bf8_t>(std::numeric_limits<float>::infinity()),
abs_tol);
// positive norm float value to bf8 and back, check if holds
float pos_float = 0.0000762939f;
ASSERT_NEAR(pos_float, type_convert<float>(type_convert<bf8_t>(pos_float)), abs_tol);
// negative norm float value to bf8 and back, check if holds
float neg_float = -0.0000610351f;
ASSERT_NEAR(neg_float, type_convert<float>(type_convert<bf8_t>(neg_float)), abs_tol);
// positive subnorm float value to bf8 and back, check if holds
pos_float = 0.0000305175f;
ASSERT_NEAR(pos_float, type_convert<float>(type_convert<bf8_t>(pos_float)), abs_tol);
// negative subnorm float value to bf8 and back, check if holds
neg_float = -0.0000152587f;
ASSERT_NEAR(neg_float, type_convert<float>(type_convert<bf8_t>(neg_float)), abs_tol);
}
TEST(BF8, ConvertFP32Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-6;
// convert 0 float to bf8 and back, check if holds
ASSERT_NEAR(0.0f, type_convert<float>(f8_convert_sr<bf8_t>(0.0f)), abs_tol);
// convert minimal float to bf8 and back, check if holds
ASSERT_NEAR(std::numeric_limits<float>::min(),
type_convert<float>(f8_convert_sr<bf8_t>(std::numeric_limits<float>::min())),
abs_tol);
// convert maximal bf8_t to float and check if equal to 57344.0
ASSERT_NEAR(57344.0f, type_convert<float>(f8_convert_sr<bf8_t>(57344.0f)), abs_tol);
// convert maximal float to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR(57344.0f,
type_convert<float>(f8_convert_sr<bf8_t>(std::numeric_limits<float>::max())),
abs_tol);
// convert inf float to bf8_t and check if it is qNan
ASSERT_NEAR(type_convert<bf8_t>(0x80),
f8_convert_sr<bf8_t>(std::numeric_limits<float>::infinity()),
abs_tol);
// positive norm float value to bf8 and back, check if holds
float pos_float = 0.0000762939f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_t>(pos_float)), abs_tol);
// negative norm float value to bf8 and back, check if holds
float neg_float = -0.0000610351f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_t>(neg_float)), abs_tol);
// positive subnorm float value to bf8 and back, check if holds
pos_float = 0.0000305175f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<bf8_t>(pos_float)), abs_tol);
// negative subnorm float value to bf8 and back, check if holds
neg_float = -0.0000152587f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<bf8_t>(neg_float)), abs_tol);
}
TEST(BF8, ConvertFP16Nearest)
{
// fix the tolerance value
float abs_tol = 1e-3;
// convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(type_convert<bf8_t>(half_t{0.0})), abs_tol);
// convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(type_convert<bf8_t>(ck::NumericLimits<half_t>::Min())),
abs_tol);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
ASSERT_NEAR(
half_t{57344.0}, type_convert<half_t>(type_convert<bf8_t>(half_t{57344.0})), abs_tol);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR(half_t{57344.0},
type_convert<half_t>(type_convert<bf8_t>(ck::NumericLimits<half_t>::Max())),
abs_tol);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
ASSERT_NEAR(type_convert<bf8_t>(0x80),
type_convert<bf8_t>(ck::NumericLimits<half_t>::QuietNaN()),
abs_tol);
// positive norm fp16 value to bf8 and back, check if holds
half_t pos_half = half_t{0.0000762939};
ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<bf8_t>(pos_half)), abs_tol);
// negative norm fp16 value to bf8 and back, check if holds
half_t neg_half = half_t{-0.0000610351};
ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<bf8_t>(neg_half)), abs_tol);
// positive subnorm fp16 value to bf8 and back, check if holds
pos_half = half_t{0.0000305175};
ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<bf8_t>(pos_half)), abs_tol);
// negative subnorm fp16 value to bf8 and back, check if holds
neg_half = half_t{-0.0000152587};
ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<bf8_t>(neg_half)), abs_tol);
}
TEST(BF8, ConvertFP16Stochastic)
{
// fix the tolerance value
float abs_tol = 1e-3;
// convert 0 fp16 to bf8 and back, check if holds
ASSERT_NEAR(half_t{0.0}, type_convert<half_t>(f8_convert_sr<bf8_t>(half_t{0.0})), abs_tol);
// convert minimal fp16 to bf8 and back, check if holds
ASSERT_NEAR(ck::NumericLimits<half_t>::Min(),
type_convert<half_t>(f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::Min())),
abs_tol);
// convert maximal bf8_t to fp16 and check if equal to 57344.0
ASSERT_NEAR(
half_t{57344.0}, type_convert<half_t>(f8_convert_sr<bf8_t>(half_t{57344.0})), abs_tol);
// convert maximal fp16 to bf8 and back, check if clipped to 57344.0
ASSERT_NEAR(half_t{57344.0},
type_convert<half_t>(f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::Max())),
abs_tol);
// convert QuietNaN fp16 to bf8_t and check if it is QuietNaN
ASSERT_NEAR(type_convert<bf8_t>(0x80),
f8_convert_sr<bf8_t>(ck::NumericLimits<half_t>::QuietNaN()),
abs_tol);
// positive norm fp16 value to bf8 and back, check if holds
half_t pos_half = half_t{0.0000762939};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_t>(pos_half)), abs_tol);
// negative norm fp16 value to bf8 and back, check if holds
half_t neg_half = half_t{-0.0000610351};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_t>(neg_half)), abs_tol);
// positive subnorm fp16 value to bf8 and back, check if holds
pos_half = half_t{0.0000305175};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<bf8_t>(pos_half)), abs_tol);
// negative subnorm fp16 value to bf8 and back, check if holds
neg_half = half_t{-0.0000152587};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<bf8_t>(neg_half)), abs_tol);
}
...@@ -12,10 +12,11 @@ using ck::type_convert; ...@@ -12,10 +12,11 @@ using ck::type_convert;
TEST(FP8, NumericLimits) TEST(FP8, NumericLimits)
{ {
EXPECT_EQ(ck::NumericLimits<f8_t>::Min(), 0x08); // constants given for negative zero nan mode
EXPECT_EQ(ck::NumericLimits<f8_t>::Max(), 0x77); EXPECT_EQ(ck::NumericLimits<f8_t>::Min(), type_convert<f8_t>(0x08));
EXPECT_EQ(ck::NumericLimits<f8_t>::Lowest(), 0xF7); EXPECT_EQ(ck::NumericLimits<f8_t>::Max(), type_convert<f8_t>(0x7F));
EXPECT_EQ(ck::NumericLimits<f8_t>::QuietNaN(), 0x80); EXPECT_EQ(ck::NumericLimits<f8_t>::Lowest(), type_convert<f8_t>(0xFF));
EXPECT_EQ(ck::NumericLimits<f8_t>::QuietNaN(), type_convert<f8_t>(0x80));
} }
TEST(FP8, ConvertFP32Nearest) TEST(FP8, ConvertFP32Nearest)
...@@ -35,12 +36,20 @@ TEST(FP8, ConvertFP32Nearest) ...@@ -35,12 +36,20 @@ TEST(FP8, ConvertFP32Nearest)
type_convert<float>(type_convert<f8_t>(std::numeric_limits<float>::max())), type_convert<float>(type_convert<f8_t>(std::numeric_limits<float>::max())),
abs_tol); abs_tol);
// convert inf float to f8_t and check if it is qNan // convert inf float to f8_t and check if it is qNan
ASSERT_NEAR(0x80, type_convert<f8_t>(std::numeric_limits<float>::infinity()), abs_tol); ASSERT_NEAR(type_convert<f8_t>(0x80),
// positive float value to fp8 and back, check if holds type_convert<f8_t>(std::numeric_limits<float>::infinity()),
float pos_float = 0.0078125f; abs_tol);
// positive norm float value to fp8 and back, check if holds
float pos_float = 0.017578125f;
ASSERT_NEAR(pos_float, type_convert<float>(type_convert<f8_t>(pos_float)), abs_tol);
// negative norm float value to fp8 and back, check if holds
float neg_float = -0.015625f;
ASSERT_NEAR(neg_float, type_convert<float>(type_convert<f8_t>(neg_float)), abs_tol);
// positive subnorm float value to fp8 and back, check if holds
pos_float = 0.00390625f;
ASSERT_NEAR(pos_float, type_convert<float>(type_convert<f8_t>(pos_float)), abs_tol); ASSERT_NEAR(pos_float, type_convert<float>(type_convert<f8_t>(pos_float)), abs_tol);
// negative float value to fp8 and back, check if holds // negative subnorm float value to fp8 and back, check if holds
float neg_float = -0.0156250f; neg_float = -0.001953125f;
ASSERT_NEAR(neg_float, type_convert<float>(type_convert<f8_t>(neg_float)), abs_tol); ASSERT_NEAR(neg_float, type_convert<float>(type_convert<f8_t>(neg_float)), abs_tol);
} }
...@@ -61,12 +70,20 @@ TEST(FP8, ConvertFP32Stochastic) ...@@ -61,12 +70,20 @@ TEST(FP8, ConvertFP32Stochastic)
type_convert<float>(f8_convert_sr<f8_t>(std::numeric_limits<float>::max())), type_convert<float>(f8_convert_sr<f8_t>(std::numeric_limits<float>::max())),
abs_tol); abs_tol);
// convert inf float to f8_t and check if it is qNan // convert inf float to f8_t and check if it is qNan
ASSERT_NEAR(0x80, f8_convert_sr<f8_t>(std::numeric_limits<float>::infinity()), abs_tol); ASSERT_NEAR(type_convert<f8_t>(0x80),
// positive float value to fp8 and back, check if holds f8_convert_sr<f8_t>(std::numeric_limits<float>::infinity()),
float pos_float = 0.0078125f; abs_tol);
// positive norm float value to fp8 and back, check if holds
float pos_float = 0.017578125f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol);
// negative norm float value to fp8 and back, check if holds
float neg_float = -0.015625f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_t>(neg_float)), abs_tol);
// positive subnorm float value to fp8 and back, check if holds
pos_float = 0.00390625f;
ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol); ASSERT_NEAR(pos_float, type_convert<float>(f8_convert_sr<f8_t>(pos_float)), abs_tol);
// negative float value to fp8 and back, check if holds // negative subnorm float value to fp8 and back, check if holds
float neg_float = -0.0156250f; neg_float = -0.001953125f;
ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_t>(neg_float)), abs_tol); ASSERT_NEAR(neg_float, type_convert<float>(f8_convert_sr<f8_t>(neg_float)), abs_tol);
} }
...@@ -87,12 +104,20 @@ TEST(FP8, ConvertFP16Nearest) ...@@ -87,12 +104,20 @@ TEST(FP8, ConvertFP16Nearest)
type_convert<half_t>(type_convert<f8_t>(ck::NumericLimits<half_t>::Max())), type_convert<half_t>(type_convert<f8_t>(ck::NumericLimits<half_t>::Max())),
abs_tol); abs_tol);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN // convert QuietNaN fp16 to f8_t and check if it is QuietNaN
ASSERT_NEAR(0x80, type_convert<f8_t>(ck::NumericLimits<half_t>::QuietNaN()), abs_tol); ASSERT_NEAR(type_convert<f8_t>(0x80),
// positive fp16 value to fp8 and back, check if holds type_convert<f8_t>(ck::NumericLimits<half_t>::QuietNaN()),
half_t pos_half = half_t{0.0078125}; abs_tol);
// positive norm fp16 value to fp8 and back, check if holds
half_t pos_half = half_t{0.017578125};
ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<f8_t>(pos_half)), abs_tol);
// negative norm fp16 value to fp8 and back, check if holds
half_t neg_half = half_t{-0.015625};
ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<f8_t>(neg_half)), abs_tol);
// positive subnorm fp16 value to fp8 and back, check if holds
pos_half = half_t{0.00390625};
ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<f8_t>(pos_half)), abs_tol); ASSERT_NEAR(pos_half, type_convert<half_t>(type_convert<f8_t>(pos_half)), abs_tol);
// negative fp16 value to fp8 and back, check if holds // negative subnorm fp16 value to fp8 and back, check if holds
half_t neg_half = half_t{-0.0156250}; neg_half = half_t{-0.001953125};
ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<f8_t>(neg_half)), abs_tol); ASSERT_NEAR(neg_half, type_convert<half_t>(type_convert<f8_t>(neg_half)), abs_tol);
} }
...@@ -113,11 +138,19 @@ TEST(FP8, ConvertFP16Stochastic) ...@@ -113,11 +138,19 @@ TEST(FP8, ConvertFP16Stochastic)
type_convert<half_t>(f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::Max())), type_convert<half_t>(f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::Max())),
abs_tol); abs_tol);
// convert QuietNaN fp16 to f8_t and check if it is QuietNaN // convert QuietNaN fp16 to f8_t and check if it is QuietNaN
ASSERT_NEAR(0x80, f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::QuietNaN()), abs_tol); ASSERT_NEAR(type_convert<f8_t>(0x80),
// positive fp16 value to fp8 and back, check if holds f8_convert_sr<f8_t>(ck::NumericLimits<half_t>::QuietNaN()),
half_t pos_half = half_t{0.0078125}; abs_tol);
// positive norm fp16 value to fp8 and back, check if holds
half_t pos_half = half_t{0.017578125};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol);
// negative norm fp16 value to fp8 and back, check if holds
half_t neg_half = half_t{-0.015625};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_t>(neg_half)), abs_tol);
// positive subnorm fp16 value to fp8 and back, check if holds
pos_half = half_t{0.00390625};
ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol); ASSERT_NEAR(pos_half, type_convert<half_t>(f8_convert_sr<f8_t>(pos_half)), abs_tol);
// negative fp16 value to fp8 and back, check if holds // negative subnorm fp16 value to fp8 and back, check if holds
half_t neg_half = half_t{-0.0156250}; neg_half = half_t{-0.001953125};
ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_t>(neg_half)), abs_tol); ASSERT_NEAR(neg_half, type_convert<half_t>(f8_convert_sr<f8_t>(neg_half)), abs_tol);
} }
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp" #include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"
using namespace ck::tensor_layout::convolution;
template <typename Tuple> template <typename Tuple>
class TestGroupedConvndBwdWeight : public ::testing::Test class TestGroupedConvndBwdWeight : public ::testing::Test
{ {
...@@ -27,28 +29,59 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ...@@ -27,28 +29,59 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
using NDimSpatial = std::tuple_element_t<6, Tuple>; using NDimSpatial = std::tuple_element_t<6, Tuple>;
std::vector<ck::utils::conv::ConvParam> conv_params; std::vector<ck::utils::conv::ConvParam> conv_params;
ck::index_t split_k{2}; std::vector<ck::index_t> split_ks{1, 2};
bool skip_case(const ck::utils::conv::ConvParam& params, const ck::index_t split_k)
{
// Odd K or C values are supported only by DL kernel (only applies to fp16)
// DL kernel currently supports only `split_k=1`
if constexpr(std::is_same_v<InDataType, ck::half_t>)
{
if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0))
{
return true;
}
}
// 1d NWGC is only supported by DL kernel
// DL kernel is only supported for split_k=1
if constexpr(std::is_same_v<InLayout, NWGC> && std::is_same_v<OutLayout, NWGK>)
{
if(split_k != 1)
{
return true;
}
}
return false;
}
void Run() void Run()
{ {
EXPECT_FALSE(conv_params.empty()); EXPECT_FALSE(conv_params.empty());
bool pass = true; bool pass = true;
for(auto& param : conv_params) for(auto split_k : split_ks)
{ {
pass = pass && ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial{}, for(auto& param : conv_params)
InLayout, {
WeiLayout, if(!skip_case(param, split_k))
OutLayout, {
InDataType, pass = pass && ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial{},
WeiDataType, InLayout,
OutDataType>( WeiLayout,
true, // do_verification OutLayout,
1, // init_method: integer value InDataType,
false, // do_log WeiDataType,
false, // time_kernel OutDataType>(
param, true, // do_verification
split_k); 1, // init_method: integer value
false, // do_log
false, // time_kernel
param,
split_k);
}
}
} }
EXPECT_TRUE(pass); EXPECT_TRUE(pass);
} }
...@@ -69,12 +102,13 @@ class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight<Tuple> ...@@ -69,12 +102,13 @@ class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight<Tuple>
{ {
}; };
using namespace ck::tensor_layout::convolution;
using KernelTypes1d = ::testing::Types< using KernelTypes1d = ::testing::Types<
std::tuple<float, float, float, GNWC, GKXC, GNWK, ck::Number<1>>, std::tuple<float, float, float, GNWC, GKXC, GNWK, ck::Number<1>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNWC, GKXC, GNWK, ck::Number<1>>, std::tuple<ck::half_t, ck::half_t, ck::half_t, GNWC, GKXC, GNWK, ck::Number<1>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNWC, GKXC, GNWK, ck::Number<1>>>; std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNWC, GKXC, GNWK, ck::Number<1>>,
std::tuple<float, float, float, NWGC, GKXC, NWGK, ck::Number<1>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, NWGC, GKXC, NWGK, ck::Number<1>>,
std::tuple<ck::bhalf_t, float, ck::bhalf_t, NWGC, GKXC, NWGK, ck::Number<1>>>;
using KernelTypes2d = ::testing::Types< using KernelTypes2d = ::testing::Types<
std::tuple<float, float, float, GNHWC, GKYXC, GNHWK, ck::Number<2>>, std::tuple<float, float, float, GNHWC, GKYXC, GNHWK, ck::Number<2>>,
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNHWC, GKYXC, GNHWK, ck::Number<2>>, std::tuple<ck::half_t, ck::half_t, ck::half_t, GNHWC, GKYXC, GNHWK, ck::Number<2>>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment