"examples/vscode:/vscode.git/clone" did not exist on "2e53936c97d167713c9e97414160124861fa4b68"
Unverified Commit bdddf1ea authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

[CK_TILE] Add error threshold calculation for gemm examples (#1821)

parent 0fcbb25f
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf,
...@@ -148,9 +168,18 @@ int run_gemm_example_with_layouts(int argc, ...@@ -148,9 +168,18 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>( ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_host_ref); a_m_k, b_k_n, c_m_n_host_ref);
const float max_accumulated_value =
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref); *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
} }
else if(arg_parser.get_int("v") == 2) else if(arg_parser.get_int("v") == 2)
...@@ -196,8 +225,18 @@ int run_gemm_example_with_layouts(int argc, ...@@ -196,8 +225,18 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::hip_check_error(hipFree(d_C)); ck_tile::hip_check_error(hipFree(d_C));
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_gpu_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl; std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf,
...@@ -179,8 +199,18 @@ int run_batched_gemm_example_with_layouts(int argc, ...@@ -179,8 +199,18 @@ int run_batched_gemm_example_with_layouts(int argc,
ck_tile::reference_batched_gemm<ADataType, BDataType, AccDataType, CDataType>( ck_tile::reference_batched_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_n_k, c_m_n_host_ref); a_m_k, b_n_k, c_m_n_host_ref);
const float max_accumulated_value =
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref); *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
} }
...@@ -240,7 +270,18 @@ int run_batched_gemm_example_with_layouts(int argc, ...@@ -240,7 +270,18 @@ int run_batched_gemm_example_with_layouts(int argc,
ck_tile::hip_check_error(hipFree(d_C)); ck_tile::hip_check_error(hipFree(d_C));
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_gpu_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl; std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
float invoke_gemm(int n_warmup, float invoke_gemm(int n_warmup,
int n_repeat, int n_repeat,
...@@ -162,7 +182,18 @@ int run_grouped_gemm_example_with_layouts(int argc, ...@@ -162,7 +182,18 @@ int run_grouped_gemm_example_with_layouts(int argc,
c_m_n_host_ref.SetZero(); c_m_n_host_ref.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>( ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref); const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(Ks[i], 1 /*kbatch*/, max_accumulated_value);
pass &= ck_tile::check_err(c_m_n_tensors[i],
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "gemm[" << i
<< "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
} }
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp" #include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/bit_cast.hpp"
...@@ -376,6 +376,16 @@ struct numeric<bfloat16_t> ...@@ -376,6 +376,16 @@ struct numeric<bfloat16_t>
} }
}; };
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<bfloat16_t>
{
static constexpr int exp = 8;
static constexpr int mant = 7;
};
#if CK_TILE_USE_CUSTOM_DATA_TYPE #if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t) CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
#endif #endif
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -18,6 +18,130 @@ ...@@ -18,6 +18,130 @@
namespace ck_tile { namespace ck_tile {
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_relative_threshold(const int number_of_accumulations = 1)
{
using F8 = ck_tile::fp8_t;
using F16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
static_assert(std::is_same_v<ComputeDataType, F8> || std::is_same_v<ComputeDataType, F16> ||
std::is_same_v<ComputeDataType, BF16> ||
std::is_same_v<ComputeDataType, F32> || std::is_same_v<ComputeDataType, I8> ||
std::is_same_v<ComputeDataType, I32> || std::is_same_v<ComputeDataType, int>,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
double compute_error = 0;
if constexpr(std::is_same_v<ComputeDataType, I8> || std::is_same_v<ComputeDataType, I32> ||
std::is_same_v<ComputeDataType, int>)
{
return 0;
}
else
{
compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
}
static_assert(std::is_same_v<OutDataType, F8> || std::is_same_v<OutDataType, F16> ||
std::is_same_v<OutDataType, BF16> || std::is_same_v<OutDataType, F32> ||
std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> ||
std::is_same_v<OutDataType, int>,
"Warning: Unhandled OutDataType for setting up the relative threshold!");
double output_error = 0;
if constexpr(std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> ||
std::is_same_v<OutDataType, int>)
{
return 0;
}
else
{
output_error = std::pow(2, -numeric_traits<OutDataType>::mant) * 0.5;
}
double midway_error = std::max(compute_error, output_error);
static_assert(std::is_same_v<AccDataType, F8> || std::is_same_v<AccDataType, F16> ||
std::is_same_v<AccDataType, BF16> || std::is_same_v<AccDataType, F32> ||
std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> ||
std::is_same_v<AccDataType, int>,
"Warning: Unhandled AccDataType for setting up the relative threshold!");
double acc_error = 0;
if constexpr(std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> ||
std::is_same_v<AccDataType, int>)
{
return 0;
}
else
{
acc_error = std::pow(2, -numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
}
return std::max(acc_error, midway_error);
}
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
{
using F8 = ck_tile::fp8_t;
using F16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
static_assert(std::is_same_v<ComputeDataType, F8> || std::is_same_v<ComputeDataType, F16> ||
std::is_same_v<ComputeDataType, BF16> ||
std::is_same_v<ComputeDataType, F32> || std::is_same_v<ComputeDataType, I8> ||
std::is_same_v<ComputeDataType, I32> || std::is_same_v<ComputeDataType, int>,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
auto expo = std::log2(std::abs(max_possible_num));
double compute_error = 0;
if constexpr(std::is_same_v<ComputeDataType, I8> || std::is_same_v<ComputeDataType, I32> ||
std::is_same_v<ComputeDataType, int>)
{
return 0;
}
else
{
compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
}
static_assert(std::is_same_v<OutDataType, F8> || std::is_same_v<OutDataType, F16> ||
std::is_same_v<OutDataType, BF16> || std::is_same_v<OutDataType, F32> ||
std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> ||
std::is_same_v<OutDataType, int>,
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
double output_error = 0;
if constexpr(std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> ||
std::is_same_v<OutDataType, int>)
{
return 0;
}
else
{
output_error = std::pow(2, expo - numeric_traits<OutDataType>::mant) * 0.5;
}
double midway_error = std::max(compute_error, output_error);
static_assert(std::is_same_v<AccDataType, F8> || std::is_same_v<AccDataType, F16> ||
std::is_same_v<AccDataType, BF16> || std::is_same_v<AccDataType, F32> ||
std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> ||
std::is_same_v<AccDataType, int>,
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
double acc_error = 0;
if constexpr(std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> ||
std::is_same_v<AccDataType, int>)
{
return 0;
}
else
{
acc_error =
std::pow(2, expo - numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
}
return std::max(acc_error, midway_error);
}
template <typename T> template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment