Commit 52cd7ade authored by Andriy Roshchenko

Verify 35_splitk_gemm on floating-point numbers.

The split-K GEMM appears to lose precision relative to the reference implementation when floating-point inputs are involved.
parent e942e568
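Background for this change, as a minimal standalone sketch (not part of the commit): split-K GEMM accumulates the K dimension in KBatch partial sums and then combines them, and because floating-point addition is not associative this can differ slightly from the single-pass reference accumulation. All names and values below are illustrative only.

// Illustrative only: summing the K dimension in KBatch chunks (as split-K
// GEMM does) can give a slightly different float result than a single
// sequential pass over K, because float addition is not associative.
#include <cstdio>
#include <vector>

int main()
{
    const int K      = 4096;
    const int KBatch = 8;

    std::vector<float> a(K), b(K);
    for(int k = 0; k < K; ++k)
    {
        a[k] = 1.0f / (k + 1); // values of mixed magnitude
        b[k] = (k % 2) ? 1.0f : -0.5f;
    }

    // Single-pass reference accumulation.
    float ref = 0.f;
    for(int k = 0; k < K; ++k)
        ref += a[k] * b[k];

    // Split-K style: accumulate each chunk separately, then combine.
    float splitk    = 0.f;
    const int chunk = K / KBatch;
    for(int kb = 0; kb < KBatch; ++kb)
    {
        float partial = 0.f;
        for(int k = kb * chunk; k < (kb + 1) * chunk; ++k)
            partial += a[k] * b[k];
        splitk += partial;
    }

    std::printf("ref = %.9f, splitk = %.9f, diff = %.3e\n",
                ref, splitk, static_cast<double>(ref - splitk));
}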
#pragma once
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 2e-1;
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 2e-1;
}
else
{
return 1e-3;
}
}
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 2e-1;
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 2e-1;
}
else
{
return 1e-3;
}
}
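For reference, the exact comparison performed by ck::utils::check_err is not shown in this hunk; a typical way an (rtol, atol) pair like the one returned above is applied, as a rough sketch with made-up helper names, is:

// Sketch only: a common absolute-plus-relative tolerance check; the real
// ck::utils::check_err may differ in details.
#include <cmath>
#include <cstddef>
#include <vector>

template <typename T>
bool values_close(T out, T ref, double rtol, double atol)
{
    const double o = static_cast<double>(out);
    const double r = static_cast<double>(ref);
    return std::abs(o - r) <= atol + rtol * std::abs(r);
}

template <typename T>
bool all_close(const std::vector<T>& out, const std::vector<T>& ref, double rtol, double atol)
{
    if(out.size() != ref.size())
        return false;
    for(std::size_t i = 0; i < out.size(); ++i)
        if(!values_close(out[i], ref[i], rtol, atol))
            return false;
    return true;
}

// Hypothetical usage with the helpers above, e.g. for CDataType = float:
// bool ok = all_close(device_vals, host_vals, get_rtol<float>(), get_atol<float>());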
struct ProblemSize final
{
ck::index_t M = 3840;
...@@ -15,9 +97,10 @@ struct ProblemSize final
struct ExecutionConfig final
{
- bool do_verification = true;
- int init_method = 1;
+ // 0 - no verification, 1 - CPU, 2 - GPU, 3 - CPU + GPU
+ int do_verification = 1;
+ int init_method = 7;
bool time_kernel = false;
};
bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...@@ -65,6 +148,14 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
case 6:
a_m_k.GenerateTensorValue(GeneratorTensor_PI<ADataType>{});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
case 7:
a_m_k.GenerateTensorValue(GeneratorTensor_PI_A<ADataType>{});
b_k_n.GenerateTensorValue(GeneratorTensor_PI_B<BDataType>{});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
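The GeneratorTensor_PI, GeneratorTensor_PI_A and GeneratorTensor_PI_B functors used by init methods 6 and 7 are introduced elsewhere in this change and are not visible in this hunk. Judging from the diagnostic further below, which compares every output element against M_PI, they are presumably designed so that each K-length dot product of A and B reduces to pi. A hypothetical sketch of such a generator pair, assuming the usual index-to-value functor shape of the host tensor generators:

// Hypothetical sketch, not the actual CK generators: A(m, k) = pi / K and
// B(k, n) = 1, so C(m, n) = sum_k A(m, k) * B(k, n) should equal pi, and any
// accumulation error over the K additions shows up directly in the output.
#include <cmath>

template <typename T>
struct GeneratorTensorPiA_sketch
{
    int k_dim = 1; // K extent of the GEMM, assumed to be supplied at construction

    template <typename... Is>
    T operator()(Is...) const
    {
        return static_cast<T>(M_PI / k_dim);
    }
};

template <typename T>
struct GeneratorTensorPiB_sketch
{
    template <typename... Is>
    T operator()(Is...) const
    {
        return static_cast<T>(1);
    }
};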
...@@ -123,7 +214,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
invoker.Run(argument, StreamConfig{nullptr, false});
bool pass = true;
- if(config.do_verification)
+ if((config.do_verification == 1) || (config.do_verification == 3))
{
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
...@@ -142,6 +233,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
std::cout << "Running verification on CPU." << std::endl;
ref_invoker.Run(ref_argument);
if(std::is_same<CDataType, ck::half_t>::value)
...@@ -151,8 +243,78 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
}
else
{
- pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
+ pass &= ck::utils::check_err(c_m_n_device_result,
+ c_m_n_host_result,
+ "Error: Incorrect results!",
+ get_rtol<CDataType>(),
+ get_atol<CDataType>());
}
if(pass)
std::cout << "Verification on CPU: PASS" << std::endl;
if(config.init_method == 6 || config.init_method == 7)
{
std::cout << std::fixed << std::setprecision(16);
// compare the same element of the device and host results
AccDataType d = ck::type_convert<AccDataType>(c_m_n_device_result(0, 10));
AccDataType h = ck::type_convert<AccDataType>(c_m_n_host_result(0, 10));
std::cout << "device result: " << d << std::endl;
std::cout << "host result: " << h << std::endl;
std::cout << "expected result: " << M_PI << std::endl;
std::cout << "device - host: " << std::abs(d - h) << std::endl;
std::cout << "device - expected: " << std::abs(d - M_PI) << std::endl;
std::cout << "atol: " << get_atol<CDataType>() << std::endl;
std::cout << std::endl << std::endl;
}
}
if((config.do_verification == 2) || (config.do_verification == 3))
{
Tensor<CDataType> c_m_n_device_ref_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
DeviceMem c_m_n_device_ref_buf(sizeof(CDataType) *
c_m_n_device_ref_result.mDesc.GetElementSpaceSize());
// GPU verification
using ReferenceComputeType = float;
using ReferenceGemmInstanceGPU =
ck::tensor_operation::device::ReferenceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp,
ReferenceComputeType,
ReferenceComputeType>;
auto ref_gemm_gpu = ReferenceGemmInstanceGPU{};
auto ref_invoker_gpu = ref_gemm_gpu.MakeInvoker();
auto ref_argument_gpu = ref_gemm_gpu.MakeArgument(
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_ref_buf.GetDeviceBuffer()),
M,
N,
K,
a_element_op,
b_element_op,
c_element_op);
std::cout << "Running verification on GPU." << std::endl;
ref_invoker_gpu.Run(ref_argument_gpu, StreamConfig{});
c_m_n_device_ref_buf.FromDevice(c_m_n_device_ref_result.mData.data());
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass = ck::utils::check_err(c_m_n_device_result,
c_m_n_device_ref_result,
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
if(pass)
std::cout << "Verification on GPU: PASS" << std::endl;
}
if(config.time_kernel)
...@@ -205,7 +367,7 @@ bool run_splitK_gemm_example(int argc, char* argv[])
}
else
{
- printf("arg1: verification (0=no, 1=yes)\n");
+ printf("arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4: KBatch\n");
...
...@@ -16,6 +16,7 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/gpu/reference_gemm.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
...
...@@ -16,6 +16,7 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/gpu/reference_gemm.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
...
...@@ -16,6 +16,7 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/gpu/reference_gemm.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
...
...@@ -16,6 +16,7 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/gpu/reference_gemm.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
...
...@@ -16,6 +16,7 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/gpu/reference_gemm.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
...
...@@ -16,6 +16,7 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/gpu/reference_gemm.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
...
...@@ -24,6 +24,7 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/reference_tensor_operation/gpu/reference_gemm.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
...