Unverified Commit 95e93430 authored by Anthony Chang's avatar Anthony Chang Committed by GitHub
Browse files

Hotfix for gemm test (#214)

* pass by ref to avoid throwing away initialization results

* EOL CRLF -> LF
parent 3956085d
#ifndef GEMM_UTILS_HPP #ifndef GEMM_UTILS_HPP
#define GEMM_UTILS_HPP #define GEMM_UTILS_HPP
#include "check_err.hpp" #include "check_err.hpp"
#include "config.hpp" #include "config.hpp"
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
namespace ck { namespace ck {
namespace gemm_util { namespace gemm_util {
struct GemmParams struct GemmParams
{ {
GemmParams() GemmParams()
: M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0) : M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0)
{ {
} }
ck::index_t M; ck::index_t M;
ck::index_t N; ck::index_t N;
ck::index_t K; ck::index_t K;
ck::index_t StrideA; ck::index_t StrideA;
ck::index_t StrideB; ck::index_t StrideB;
ck::index_t StrideC; ck::index_t StrideC;
float alpha; float alpha;
float beta; float beta;
}; };
template <typename GemmInstance, template <typename GemmInstance,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename CDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
void RunHostGEMM(const Tensor<ADataType>& A, void RunHostGEMM(const Tensor<ADataType>& A,
const Tensor<BDataType>& B, const Tensor<BDataType>& B,
Tensor<CDataType>& C, Tensor<CDataType>& C,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
{ {
auto ref_gemm = GemmInstance{}; auto ref_gemm = GemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op); auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
} }
template <typename DeviceGemmPtr_, template <typename DeviceGemmPtr_,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename CDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
const ck::gemm_util::GemmParams& params, const ck::gemm_util::GemmParams& params,
const Tensor<ADataType>& A, const Tensor<ADataType>& A,
const Tensor<BDataType>& B, const Tensor<BDataType>& B,
Tensor<CDataType>& C, Tensor<CDataType>& C,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
{ {
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(A.mData.data()); a_m_k_device_buf.ToDevice(A.mData.data());
b_k_n_device_buf.ToDevice(B.mData.data()); b_k_n_device_buf.ToDevice(B.mData.data());
auto invoker_ptr = gemmPtr->MakeInvokerPointer(); auto invoker_ptr = gemmPtr->MakeInvokerPointer();
auto argument_ptr = auto argument_ptr =
gemmPtr->MakeArgumentPointer(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()), gemmPtr->MakeArgumentPointer(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()), static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
params.M, params.M,
params.N, params.N,
params.K, params.K,
params.StrideA, params.StrideA,
params.StrideB, params.StrideB,
params.StrideC, params.StrideC,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op);
if(!gemmPtr->IsSupportedArgument(argument_ptr.get())) if(!gemmPtr->IsSupportedArgument(argument_ptr.get()))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does " "wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"); "not support this GEMM problem");
} }
invoker_ptr->Run(argument_ptr.get()); invoker_ptr->Run(argument_ptr.get());
c_m_n_device_buf.FromDevice(C.mData.data()); c_m_n_device_buf.FromDevice(C.mData.data());
} }
template <typename DeviceGemmPtr_, template <typename DeviceGemmPtr_,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename CDataType,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
typename CLayout, typename CLayout,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
struct TestGemm struct TestGemm
{ {
auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params) auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params)
{ {
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value) if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{ {
return HostTensorDescriptor(std::vector<std::size_t>({row, col}), return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1})); std::vector<std::size_t>({stride, 1}));
} }
else else
{ {
return HostTensorDescriptor(std::vector<std::size_t>({row, col}), return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride})); std::vector<std::size_t>({1, stride}));
} }
}; };
Tensor<ADataType> a_m_k( Tensor<ADataType> a_m_k(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
Tensor<BDataType> b_k_n( Tensor<BDataType> b_k_n(
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result( Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result( Tensor<CDataType> c_m_n_device_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
auto f_generate_tensor_value = [](auto desc, auto type) { auto f_generate_tensor_value = [](auto& desc, auto type) {
using dataType = decltype(type); using dataType = decltype(type);
if(std::is_same<dataType, int8_t>::value) if(std::is_same<dataType, int8_t>::value)
{ {
desc.GenerateTensorValue(GeneratorTensor_2<int8_t>{-5, 5}); desc.GenerateTensorValue(GeneratorTensor_2<int8_t>{-5, 5});
} }
else else
{ {
desc.GenerateTensorValue(GeneratorTensor_3<dataType>{-0.5, 0.5}); desc.GenerateTensorValue(GeneratorTensor_3<dataType>{-0.5, 0.5});
} }
}; };
f_generate_tensor_value(a_m_k, ADataType{}); f_generate_tensor_value(a_m_k, ADataType{});
f_generate_tensor_value(b_k_n, BDataType{}); f_generate_tensor_value(b_k_n, BDataType{});
return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result); return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result);
} }
auto operator()(DeviceGemmPtr_& gemmPtr) auto operator()(DeviceGemmPtr_& gemmPtr)
{ {
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
<< ", CLayout = " << CLayout{}.name << std::endl; << ", CLayout = " << CLayout{}.name << std::endl;
std::cout << gemmPtr->GetTypeString() << std::endl; std::cout << gemmPtr->GetTypeString() << std::endl;
// Arrange // Arrange
ck::gemm_util::GemmParams params; ck::gemm_util::GemmParams params;
params.M = 1024; params.M = 1024;
params.N = 1024; params.N = 1024;
params.K = 1024; params.K = 1024;
params.StrideA = 1024; params.StrideA = 1024;
params.StrideB = 1024; params.StrideB = 1024;
params.StrideC = 1024; params.StrideC = 1024;
auto host_tensors = PrepareGemmTensor(params); auto host_tensors = PrepareGemmTensor(params);
const Tensor<ADataType>& a = std::get<0>(host_tensors); const Tensor<ADataType>& a = std::get<0>(host_tensors);
const Tensor<BDataType>& b = std::get<1>(host_tensors); const Tensor<BDataType>& b = std::get<1>(host_tensors);
Tensor<CDataType>& c_host = std::get<2>(host_tensors); Tensor<CDataType>& c_host = std::get<2>(host_tensors);
Tensor<CDataType>& c_device = std::get<3>(host_tensors); Tensor<CDataType>& c_device = std::get<3>(host_tensors);
auto a_element_op = AElementwiseOperation{}; auto a_element_op = AElementwiseOperation{};
auto b_element_op = BElementwiseOperation{}; auto b_element_op = BElementwiseOperation{};
auto c_element_op = CElementwiseOperation{}; auto c_element_op = CElementwiseOperation{};
using ReferenceGemmInstance = using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<ADataType, ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType, BDataType,
CDataType, CDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation>; CElementwiseOperation>;
ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>( ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
a, b, c_host, a_element_op, b_element_op, c_element_op); a, b, c_host, a_element_op, b_element_op, c_element_op);
// Act // Act
ck::gemm_util::RunDeviceGEMM( ck::gemm_util::RunDeviceGEMM(
gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op); gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op);
// Assert // Assert
bool res = false; bool res = false;
if(std::is_same<CDataType, float>::value) if(std::is_same<CDataType, float>::value)
{ {
res = ck::utils::check_err(c_device.mData, c_host.mData); res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
} }
else if(std::is_same<CDataType, ck::half_t>::value) else if(std::is_same<CDataType, ck::half_t>::value)
{ {
res = ck::utils::check_err(c_device.mData, c_host.mData); res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
} }
else if(std::is_same<CDataType, int8_t>::value) else if(std::is_same<CDataType, int8_t>::value)
{ {
res = ck::utils::check_err(c_device.mData, c_host.mData); res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
} }
return res; return res;
} }
}; };
template <typename DeviceGemmPtr_, template <typename DeviceGemmPtr_,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
typename CLayout, typename CLayout,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
struct TestGemmBF16 struct TestGemmBF16
{ {
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
auto PrepareGemmTensorBF16(const ck::gemm_util::GemmParams& params) auto PrepareGemmTensorBF16(const ck::gemm_util::GemmParams& params)
{ {
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value) if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{ {
return HostTensorDescriptor(std::vector<std::size_t>({row, col}), return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1})); std::vector<std::size_t>({stride, 1}));
} }
else else
{ {
return HostTensorDescriptor(std::vector<std::size_t>({row, col}), return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride})); std::vector<std::size_t>({1, stride}));
} }
}; };
// use fp32 host kernel to verify bf16 device kernel // use fp32 host kernel to verify bf16 device kernel
Tensor<BF16> a_m_k_bf16( Tensor<BF16> a_m_k_bf16(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
Tensor<BF16> b_k_n_bf16( Tensor<BF16> b_k_n_bf16(
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<BF16> c_m_n_device_bf16( Tensor<BF16> c_m_n_device_bf16(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<float> a_m_k_fp32( Tensor<float> a_m_k_fp32(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
Tensor<float> b_k_n_fp32( Tensor<float> b_k_n_fp32(
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<float> c_m_n_host_fp32( Tensor<float> c_m_n_host_fp32(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<float> c_m_n_device_fp32( Tensor<float> c_m_n_device_fp32(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
a_m_k_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5}); a_m_k_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5});
b_k_n_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5}); b_k_n_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5});
bf16_to_f32_(a_m_k_bf16, a_m_k_fp32); bf16_to_f32_(a_m_k_bf16, a_m_k_fp32);
bf16_to_f32_(b_k_n_bf16, b_k_n_fp32); bf16_to_f32_(b_k_n_bf16, b_k_n_fp32);
return std::make_tuple(a_m_k_bf16, return std::make_tuple(a_m_k_bf16,
b_k_n_bf16, b_k_n_bf16,
c_m_n_device_bf16, c_m_n_device_bf16,
a_m_k_fp32, a_m_k_fp32,
b_k_n_fp32, b_k_n_fp32,
c_m_n_host_fp32, c_m_n_host_fp32,
c_m_n_device_fp32); c_m_n_device_fp32);
} }
auto operator()(DeviceGemmPtr_& gemmPtr) auto operator()(DeviceGemmPtr_& gemmPtr)
{ {
// Arrange // Arrange
ck::gemm_util::GemmParams params; ck::gemm_util::GemmParams params;
params.M = 1024; params.M = 1024;
params.N = 1024; params.N = 1024;
params.K = 1024; params.K = 1024;
params.StrideA = 1024; params.StrideA = 1024;
params.StrideB = 1024; params.StrideB = 1024;
params.StrideC = 1024; params.StrideC = 1024;
auto host_tensors = PrepareGemmTensorBF16(params); auto host_tensors = PrepareGemmTensorBF16(params);
const Tensor<BF16>& a_bf16 = std::get<0>(host_tensors); const Tensor<BF16>& a_bf16 = std::get<0>(host_tensors);
const Tensor<BF16>& b_bf16 = std::get<1>(host_tensors); const Tensor<BF16>& b_bf16 = std::get<1>(host_tensors);
Tensor<BF16>& c_device_bf16 = std::get<2>(host_tensors); Tensor<BF16>& c_device_bf16 = std::get<2>(host_tensors);
Tensor<float>& a_fp32 = std::get<3>(host_tensors); Tensor<float>& a_fp32 = std::get<3>(host_tensors);
Tensor<float>& b_fp32 = std::get<4>(host_tensors); Tensor<float>& b_fp32 = std::get<4>(host_tensors);
Tensor<float>& c_host_fp32 = std::get<5>(host_tensors); Tensor<float>& c_host_fp32 = std::get<5>(host_tensors);
Tensor<float>& c_device_fp32 = std::get<6>(host_tensors); Tensor<float>& c_device_fp32 = std::get<6>(host_tensors);
auto a_element_op = AElementwiseOperation{}; auto a_element_op = AElementwiseOperation{};
auto b_element_op = BElementwiseOperation{}; auto b_element_op = BElementwiseOperation{};
auto c_element_op = CElementwiseOperation{}; auto c_element_op = CElementwiseOperation{};
// use fp32 host kernel to verify bf16 device kernel // use fp32 host kernel to verify bf16 device kernel
using ReferenceGemmInstance = using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<float, ck::tensor_operation::host::ReferenceGemm<float,
float, float,
float, float,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation>; CElementwiseOperation>;
ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>( ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
a_fp32, b_fp32, c_host_fp32, a_element_op, b_element_op, c_element_op); a_fp32, b_fp32, c_host_fp32, a_element_op, b_element_op, c_element_op);
// Act // Act
ck::gemm_util::RunDeviceGEMM(gemmPtr, ck::gemm_util::RunDeviceGEMM(gemmPtr,
params, params,
a_bf16, a_bf16,
b_bf16, b_bf16,
c_device_bf16, c_device_bf16,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op);
bf16_to_f32_(c_device_bf16, c_device_fp32); bf16_to_f32_(c_device_bf16, c_device_fp32);
// Assert // Assert
bool res = ck::utils::check_err( bool res = ck::utils::check_err(
c_device_fp32.mData, c_host_fp32.mData, "Error: incorrect results!", 1e-2f, 1e-3f); c_device_fp32.mData, c_host_fp32.mData, "Error: incorrect results!", 1e-2f, 1e-3f);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res; return res;
}; };
}; };
} // namespace gemm_util } // namespace gemm_util
} // namespace ck } // namespace ck
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment