Commit ebfa3921 authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/fix_test' into add_mfma_f64

parents 58f4d821 579e8e76
#ifndef GEMM_UTILS_HPP #ifndef GEMM_UTILS_HPP
#define GEMM_UTILS_HPP #define GEMM_UTILS_HPP
#include "check_err.hpp" #include "check_err.hpp"
#include "config.hpp" #include "config.hpp"
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
namespace ck { namespace ck {
namespace gemm_util { namespace gemm_util {
// GEMM problem descriptor: matrix sizes, strides (leading dimensions) and the
// alpha/beta scaling factors. Defaults describe a 1024x1024x1024 problem with
// stride 1024, alpha = 1, beta = 0.
struct GemmParams
{
    GemmParams()
        : M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0)
    {
    }

    ck::index_t M;       // rows of A and C
    ck::index_t N;       // columns of B and C
    ck::index_t K;       // columns of A / rows of B (reduction dimension)
    ck::index_t StrideA; // stride of A's major dimension (layout-dependent)
    ck::index_t StrideB; // stride of B's major dimension (layout-dependent)
    ck::index_t StrideC; // stride of C's major dimension (layout-dependent)
    float alpha;         // NOTE(review): alpha/beta are not consumed by the
    float beta;          // helpers in this file — presumably reserved for callers.
};
// Run the host reference GEMM `GemmInstance` on host tensors A and B, writing
// the result into C. The elementwise operations are applied per the reference
// implementation's contract.
template <typename GemmInstance,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
void RunHostGEMM(const Tensor<ADataType>& A,
                 const Tensor<BDataType>& B,
                 Tensor<CDataType>& C,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
{
    GemmInstance host_gemm{};
    // Build the argument first, then the invoker; the two are independent.
    auto argument = host_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op);
    auto invoker  = host_gemm.MakeInvoker();
    invoker.Run(argument);
}
// Copy host tensors A and B to device memory, launch the device GEMM instance
// `gemmPtr` with the problem description in `params`, and copy the device
// result back into host tensor C.
//
// Throws std::runtime_error if the instance reports the argument combination
// as unsupported (wrong layouts/sizes for its compiled configuration).
template <typename DeviceGemmPtr_,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
                   const ck::gemm_util::GemmParams& params,
                   const Tensor<ADataType>& A,
                   const Tensor<BDataType>& B,
                   Tensor<CDataType>& C,
                   AElementwiseOperation a_element_op,
                   BElementwiseOperation b_element_op,
                   CElementwiseOperation c_element_op)
{
    // Device buffers are sized from the tensor descriptors' element space,
    // which accounts for strides (may exceed M*K etc. for padded layouts).
    DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace());
    DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace());
    DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace());

    // Host -> device transfer of the inputs; C is produced by the kernel.
    a_m_k_device_buf.ToDevice(A.mData.data());
    b_k_n_device_buf.ToDevice(B.mData.data());

    auto invoker_ptr = gemmPtr->MakeInvokerPointer();
    auto argument_ptr =
        gemmPtr->MakeArgumentPointer(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
                                     static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
                                     static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
                                     params.M,
                                     params.N,
                                     params.K,
                                     params.StrideA,
                                     params.StrideB,
                                     params.StrideC,
                                     a_element_op,
                                     b_element_op,
                                     c_element_op);

    // Validate before launch: a mismatch between the instance's compiled
    // parameters and this problem must be a hard error, not a silent no-op.
    if(!gemmPtr->IsSupportedArgument(argument_ptr.get()))
    {
        throw std::runtime_error(
            "wrong! device_gemm with the specified compilation parameters does "
            "not support this GEMM problem");
    }

    invoker_ptr->Run(argument_ptr.get());

    // Device -> host transfer of the result.
    c_m_n_device_buf.FromDevice(C.mData.data());
}
template <typename DeviceGemmPtr_, template <typename DeviceGemmPtr_,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename CDataType,
typename AccDataType, typename ALayout,
typename ALayout, typename BLayout,
typename BLayout, typename CLayout,
typename CLayout, typename AElementwiseOperation,
typename AElementwiseOperation, typename BElementwiseOperation,
typename BElementwiseOperation, typename CElementwiseOperation>
typename CElementwiseOperation> struct TestGemm
struct TestGemm {
{ auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params)
auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params) {
{ auto f_host_tensor_descriptor =
auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value) {
{ return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
return HostTensorDescriptor(std::vector<std::size_t>({row, col}), std::vector<std::size_t>({stride, 1}));
std::vector<std::size_t>({stride, 1})); }
} else
else {
{ return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
return HostTensorDescriptor(std::vector<std::size_t>({row, col}), std::vector<std::size_t>({1, stride}));
std::vector<std::size_t>({1, stride})); }
} };
};
Tensor<ADataType> a_m_k(
Tensor<ADataType> a_m_k( f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); Tensor<BDataType> b_k_n(
Tensor<BDataType> b_k_n( f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); Tensor<CDataType> c_m_n_host_result(
Tensor<CDataType> c_m_n_host_result( f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); Tensor<CDataType> c_m_n_device_result(
Tensor<CDataType> c_m_n_device_result( f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
auto f_generate_tensor_value = [](auto& tensor, auto type) {
auto f_generate_tensor_value = [](auto& desc, auto type) { using dataType = decltype(type);
using dataType = decltype(type);
tensor.GenerateTensorValue(GeneratorTensor_2<dataType>{-5, 5});
if(std::is_same<dataType, int8_t>::value || std::is_same<dataType, double>::value) };
{
desc.GenerateTensorValue(GeneratorTensor_2<dataType>{-5, 5}); f_generate_tensor_value(a_m_k, ADataType{});
} f_generate_tensor_value(b_k_n, BDataType{});
else
{ return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result);
desc.GenerateTensorValue(GeneratorTensor_3<dataType>{-0.5, 0.5}); }
}
}; auto operator()(DeviceGemmPtr_& gemmPtr)
{
f_generate_tensor_value(a_m_k, ADataType{}); std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
f_generate_tensor_value(b_k_n, BDataType{}); << ", CLayout = " << CLayout{}.name << std::endl;
std::cout << gemmPtr->GetTypeString() << std::endl;
return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result);
} // Arrange
ck::gemm_util::GemmParams params;
auto operator()(DeviceGemmPtr_& gemmPtr) params.M = 1024;
{ params.N = 1024;
std::cout << "data type: " << typeid(ADataType{}).name() << std::endl; params.K = 1024;
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name params.StrideA = 1024;
<< ", CLayout = " << CLayout{}.name << std::endl; params.StrideB = 1024;
std::cout << gemmPtr->GetTypeString() << std::endl; params.StrideC = 1024;
// Arrange auto host_tensors = PrepareGemmTensor(params);
ck::gemm_util::GemmParams params;
params.M = 1024; const Tensor<ADataType>& a = std::get<0>(host_tensors);
params.N = 1024; const Tensor<BDataType>& b = std::get<1>(host_tensors);
params.K = 1024; Tensor<CDataType>& c_host = std::get<2>(host_tensors);
params.StrideA = 1024; Tensor<CDataType>& c_device = std::get<3>(host_tensors);
params.StrideB = 1024;
params.StrideC = 1024; auto a_element_op = AElementwiseOperation{};
auto b_element_op = BElementwiseOperation{};
auto host_tensors = PrepareGemmTensor(params); auto c_element_op = CElementwiseOperation{};
const Tensor<ADataType>& a = std::get<0>(host_tensors); using ReferenceGemmInstance =
const Tensor<BDataType>& b = std::get<1>(host_tensors); ck::tensor_operation::host::ReferenceGemm<ADataType,
Tensor<CDataType>& c_host = std::get<2>(host_tensors); BDataType,
Tensor<CDataType>& c_device = std::get<3>(host_tensors); CDataType,
AElementwiseOperation,
auto a_element_op = AElementwiseOperation{}; BElementwiseOperation,
auto b_element_op = BElementwiseOperation{}; CElementwiseOperation>;
auto c_element_op = CElementwiseOperation{}; ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
a, b, c_host, a_element_op, b_element_op, c_element_op);
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<ADataType, // Act
BDataType, ck::gemm_util::RunDeviceGEMM(
CDataType, gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op);
AccDataType,
AElementwiseOperation, // Assert
BElementwiseOperation, bool res = false;
CElementwiseOperation>; if(std::is_same<CDataType, float>::value)
ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>( {
a, b, c_host, a_element_op, b_element_op, c_element_op); res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
// Act }
ck::gemm_util::RunDeviceGEMM( else if(std::is_same<CDataType, ck::half_t>::value)
gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op); {
res = ck::utils::check_err(c_device.mData, c_host.mData);
// Assert std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
bool res = false; }
if(std::is_same<CDataType, double>::value) else if(std::is_same<CDataType, int8_t>::value)
{ {
res = ck::utils::check_err(c_device.mData, c_host.mData); res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
} }
else if(std::is_same<CDataType, float>::value)
{ return res;
res = ck::utils::check_err(c_device.mData, c_host.mData); }
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; };
}
else if(std::is_same<CDataType, ck::half_t>::value) template <typename DeviceGemmPtr_,
{ typename ALayout,
res = ck::utils::check_err(c_device.mData, c_host.mData); typename BLayout,
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; typename CLayout,
} typename AElementwiseOperation,
else if(std::is_same<CDataType, int8_t>::value) typename BElementwiseOperation,
{ typename CElementwiseOperation>
res = ck::utils::check_err(c_device.mData, c_host.mData); struct TestGemmBF16
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; {
} using BF16 = ck::bhalf_t;
return res; auto PrepareGemmTensorBF16(const ck::gemm_util::GemmParams& params)
} {
}; auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
// BF16 GEMM correctness test. Since the host reference runs in fp32, the
// bf16 inputs are mirrored into fp32 tensors: the device kernel consumes the
// bf16 copies while the host reference consumes the fp32 copies, and the
// device output is converted back to fp32 for comparison.
template <typename DeviceGemmPtr_,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct TestGemmBF16
{
    using BF16 = ck::bhalf_t;

    // Allocate paired bf16/fp32 tensors for the problem in `params`, fill the
    // bf16 inputs with random values in [-0.5, 0.5], and mirror them to fp32.
    // Returns (a_bf16, b_bf16, c_device_bf16, a_fp32, b_fp32, c_host_fp32,
    // c_device_fp32) — callers index this tuple positionally.
    auto PrepareGemmTensorBF16(const ck::gemm_util::GemmParams& params)
    {
        // Row-major: stride separates consecutive rows; otherwise it
        // separates consecutive columns.
        auto f_host_tensor_descriptor =
            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
                if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
                {
                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                                std::vector<std::size_t>({stride, 1}));
                }
                else
                {
                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                                std::vector<std::size_t>({1, stride}));
                }
            };

        // use fp32 host kernel to verify bf16 device kernel
        Tensor<BF16> a_m_k_bf16(
            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
        Tensor<BF16> b_k_n_bf16(
            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
        Tensor<BF16> c_m_n_device_bf16(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));

        Tensor<float> a_m_k_fp32(
            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
        Tensor<float> b_k_n_fp32(
            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
        Tensor<float> c_m_n_host_fp32(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
        Tensor<float> c_m_n_device_fp32(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));

        // Small values keep the bf16 rounding error within the comparison
        // tolerances used in operator().
        a_m_k_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5});
        b_k_n_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5});

        // fp32 mirrors are converted FROM the bf16 data (not regenerated), so
        // host and device kernels see numerically identical inputs.
        bf16_to_f32_(a_m_k_bf16, a_m_k_fp32);
        bf16_to_f32_(b_k_n_bf16, b_k_n_fp32);

        return std::make_tuple(a_m_k_bf16,
                               b_k_n_bf16,
                               c_m_n_device_bf16,
                               a_m_k_fp32,
                               b_k_n_fp32,
                               c_m_n_host_fp32,
                               c_m_n_device_fp32);
    }

    // Execute the bf16 device GEMM and verify it against the fp32 host
    // reference. Returns true when results agree within (1e-2, 1e-3) rtol/atol.
    auto operator()(DeviceGemmPtr_& gemmPtr)
    {
        // Arrange
        ck::gemm_util::GemmParams params;
        params.M       = 1024;
        params.N       = 1024;
        params.K       = 1024;
        params.StrideA = 1024;
        params.StrideB = 1024;
        params.StrideC = 1024;

        auto host_tensors           = PrepareGemmTensorBF16(params);
        const Tensor<BF16>& a_bf16  = std::get<0>(host_tensors);
        const Tensor<BF16>& b_bf16  = std::get<1>(host_tensors);
        Tensor<BF16>& c_device_bf16 = std::get<2>(host_tensors);
        Tensor<float>& a_fp32       = std::get<3>(host_tensors);
        Tensor<float>& b_fp32       = std::get<4>(host_tensors);
        Tensor<float>& c_host_fp32  = std::get<5>(host_tensors);
        Tensor<float>& c_device_fp32 = std::get<6>(host_tensors);

        auto a_element_op = AElementwiseOperation{};
        auto b_element_op = BElementwiseOperation{};
        auto c_element_op = CElementwiseOperation{};

        // use fp32 host kernel to verify bf16 device kernel
        using ReferenceGemmInstance =
            ck::tensor_operation::host::ReferenceGemm<float,
                                                      float,
                                                      float,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation>;
        ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
            a_fp32, b_fp32, c_host_fp32, a_element_op, b_element_op, c_element_op);

        // Act
        ck::gemm_util::RunDeviceGEMM(gemmPtr,
                                     params,
                                     a_bf16,
                                     b_bf16,
                                     c_device_bf16,
                                     a_element_op,
                                     b_element_op,
                                     c_element_op);

        // Convert the device's bf16 output to fp32 so both operands of the
        // comparison are in the same precision.
        bf16_to_f32_(c_device_bf16, c_device_fp32);

        // Assert — loose tolerances account for bf16's ~8-bit mantissa.
        bool res = ck::utils::check_err(
            c_device_fp32.mData, c_host_fp32.mData, "Error: incorrect results!", 1e-2f, 1e-3f);
        std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;

        return res;
    };
};
} // namespace gemm_util
} // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment