Unverified Commit 485ea46a authored by rocking5566, committed by GitHub

Gemm_c_shuffle (4 layouts) X (fp32 bf16 int8) (#131)



* [What] Separate the fixed-point GEMM from the GEMM example
[Why] Let the gemm_int8 example be pure GEMM.
[What]
1. Add gemm_requant_relu_requant.
2. Let CDataType be int32 in pure GEMM, because no one uses an int8 CDataType there; int8 output is instead part of gemm_requant_relu_requant.

* Fix path

* Revise CMakeLists after merging develop

* Add gemm fp16 test

* Extract PrepareGemmTensor

* Extract TestGemm

* Add tests for different layouts

* Add 4 layouts of the shuffle version for fp32

* Add 4 layouts of the shuffle version for int8

* Add 4 layouts of the shuffle version for bf16

* Replace all DeviceGemmPtr_ with DeviceGemmNoOpPtr to fit the naming convention

* Add test for the non-shuffle version of gemm

* Fix typo

* Print kernel information

* Add the rest of the fp32 kernels to the test

* 1. Add the rest of the fp16 device ops.
2. Mark the invalid device operations
Co-authored-by: rocking <chunylai@amd.com>
parent b51808d7
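The "4 layouts" in the title are the four (A, B) storage-order combinations. The mapping below is inferred from the TestGemm instantiations in this diff (the instance-name suffixes km/mk and kn/nk encode whether A and B are stored transposed); it is an editorial sketch, not code from this commit:

#include <cstdio>

// Inferred from the tests below: the instance name encodes the storage order
// of A, B and C. "mk" = A stored row-major (M x K), "km" = A column-major;
// "kn" = B row-major (K x N), "nk" = B column-major; C ("mn") is always
// row-major in these tests.
int main()
{
    struct Layout { const char* suffix; const char* a; const char* b; };
    const Layout table[] = {
        {"km_kn_mn", "ColumnMajor", "RowMajor"},
        {"km_nk_mn", "ColumnMajor", "ColumnMajor"},
        {"mk_kn_mn", "RowMajor", "RowMajor"},
        {"mk_nk_mn", "RowMajor", "ColumnMajor"},
    };
    for(const Layout& l : table)
        std::printf("%s -> ALayout=%s, BLayout=%s, CLayout=RowMajor\n", l.suffix, l.a, l.b);
    return 0;
}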
@@ -23,7 +23,7 @@
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 
-using DeviceGemmPtr_ =
+using DeviceGemmNoOpPtr =
     ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 ck::tensor_operation::element_wise::PassThrough>;
@@ -32,106 +32,122 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace device_gemm_instance {
-void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmPtr_>&);
-}
+void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+} // namespace device_gemm_instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
 
-namespace {
-
-using ADataType = float;
-using BDataType = float;
-using CDataType = float;
-using AccDataType = float;
-
-using ALayout = ck::tensor_layout::gemm::RowMajor;
-using BLayout = ck::tensor_layout::gemm::ColumnMajor;
-using CLayout = ck::tensor_layout::gemm::RowMajor;
-
-auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params)
-{
-    auto f_host_tensor_descriptor =
-        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
-            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
-            {
-                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
-                                            std::vector<std::size_t>({stride, 1}));
-            }
-            else
-            {
-                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
-                                            std::vector<std::size_t>({1, stride}));
-            }
-        };
-
-    Tensor<ADataType> a_m_k(
-        f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
-    Tensor<BDataType> b_k_n(
-        f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
-    Tensor<CDataType> c_m_n_host_result(
-        f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
-    Tensor<CDataType> c_m_n_device_result(
-        f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
-
-    a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
-    b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
-
-    return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result);
-}
-
-bool TestGemm(DeviceGemmPtr_& gemmPtr)
-{
-    // Arrange
-    ck::gemm_util::GemmParams params;
-    params.M       = 1024;
-    params.N       = 1024;
-    params.K       = 1024;
-    params.StrideA = 1024;
-    params.StrideB = 1024;
-    params.StrideC = 1024;
-
-    auto host_tensors = PrepareGemmTensor(params);
-    const Tensor<ADataType>& a  = std::get<0>(host_tensors);
-    const Tensor<BDataType>& b  = std::get<1>(host_tensors);
-    Tensor<CDataType>& c_host   = std::get<2>(host_tensors);
-    Tensor<CDataType>& c_device = std::get<3>(host_tensors);
-
-    auto a_element_op = PassThrough{};
-    auto b_element_op = PassThrough{};
-    auto c_element_op = PassThrough{};
-
-    using ReferenceGemmInstance = ck::tensor_operation::host::
-        ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
-    ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
-        a, b, c_host, a_element_op, b_element_op, c_element_op);
-
-    // Act
-    ck::gemm_util::RunDeviceGEMM(
-        gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op);
-
-    // Assert
-    bool res = test_util::check_err(
-        c_device.mData, c_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
-    std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
-
-    return res;
-}
-
-} // anonymous namespace
-
 int main()
 {
-    std::vector<DeviceGemmPtr_> gemmPtrs;
-    ck::tensor_operation::device::device_gemm_instance::
-        add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);
+    using ADataType = float;
+    using BDataType = float;
+    using CDataType = float;
+
+    using RowMajor    = ck::tensor_layout::gemm::RowMajor;
+    using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
 
     bool res = true;
+    std::vector<DeviceGemmNoOpPtr> gemmPtrs;
+    ck::tensor_operation::device::device_gemm_instance::
+        add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemmPtrs);
+    ck::tensor_operation::device::device_gemm_instance::
+        add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemmPtrs);
+    ck::tensor_operation::device::device_gemm_instance::
+        add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemmPtrs);
+
+    for(auto& gemmPtr : gemmPtrs)
+    {
+        res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
+                                       ADataType,
+                                       BDataType,
+                                       CDataType,
+                                       ColumnMajor,
+                                       RowMajor,
+                                       RowMajor,
+                                       PassThrough,
+                                       PassThrough,
+                                       PassThrough>{}(gemmPtr);
+    }
+
+    gemmPtrs.clear();
+    ck::tensor_operation::device::device_gemm_instance::
+        add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemmPtrs);
+    ck::tensor_operation::device::device_gemm_instance::
+        add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemmPtrs);
+    ck::tensor_operation::device::device_gemm_instance::
+        add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemmPtrs);
+
     for(auto& gemmPtr : gemmPtrs)
     {
-        res &= TestGemm(gemmPtr);
+        res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
+                                       ADataType,
+                                       BDataType,
+                                       CDataType,
+                                       ColumnMajor,
+                                       ColumnMajor,
+                                       RowMajor,
+                                       PassThrough,
+                                       PassThrough,
+                                       PassThrough>{}(gemmPtr);
     }
+
+    gemmPtrs.clear();
+    ck::tensor_operation::device::device_gemm_instance::
+        add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemmPtrs);
+    ck::tensor_operation::device::device_gemm_instance::
+        add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemmPtrs);
+    ck::tensor_operation::device::device_gemm_instance::
+        add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemmPtrs);
+
+    for(auto& gemmPtr : gemmPtrs)
+    {
+        res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
+                                       ADataType,
+                                       BDataType,
+                                       CDataType,
+                                       RowMajor,
+                                       RowMajor,
+                                       RowMajor,
+                                       PassThrough,
+                                       PassThrough,
+                                       PassThrough>{}(gemmPtr);
+    }
+
+    gemmPtrs.clear();
+    ck::tensor_operation::device::device_gemm_instance::
+        add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);
+    ck::tensor_operation::device::device_gemm_instance::
+        add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);
+    ck::tensor_operation::device::device_gemm_instance::
+        add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);
+
+    for(auto& gemmPtr : gemmPtrs)
+    {
+        res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
+                                       ADataType,
+                                       BDataType,
+                                       CDataType,
+                                       RowMajor,
+                                       ColumnMajor,
+                                       RowMajor,
+                                       PassThrough,
+                                       PassThrough,
+                                       PassThrough>{}(gemmPtr);
+    }
+
     std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
......
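The add_device_gemm_xdl_*_instances functions above are only declared here; their definitions live in the library's device_gemm_instance sources and are not part of this diff. As a rough, hypothetical sketch of the registration pattern (types and tuning configs invented for illustration), each function appends type-erased device operations to the caller's vector:

#include <memory>
#include <vector>

// Hypothetical stand-ins; the real library uses DeviceGemmXdl instantiations
// with long tuning-parameter lists behind a DeviceGemmPtr alias.
struct BaseOperator { virtual ~BaseOperator() = default; };
template <int BlockSize, int MPerBlock, int NPerBlock>
struct DeviceGemmXdl : BaseOperator {};

using DeviceGemmNoOpPtr = std::unique_ptr<BaseOperator>;

void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>& instances)
{
    // Each push_back registers one tuning configuration; the test harness
    // then runs and verifies every registered instance in a loop.
    instances.push_back(std::make_unique<DeviceGemmXdl<256, 256, 128>>());
    instances.push_back(std::make_unique<DeviceGemmXdl<256, 128, 256>>());
}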
@@ -23,7 +23,7 @@
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 
-using DeviceGemmPtr_ =
+using DeviceGemmNoOpPtr =
     ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 ck::tensor_operation::element_wise::PassThrough>;
@@ -32,105 +32,96 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace device_gemm_instance {
-void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(std::vector<DeviceGemmPtr_>&);
+void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
+void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
 }
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
 
-namespace {
-
-using ADataType = int8_t;
-using BDataType = int8_t;
-using CDataType = int8_t;
-using AccDataType = int32_t;
-
-using ALayout = ck::tensor_layout::gemm::RowMajor;
-using BLayout = ck::tensor_layout::gemm::ColumnMajor;
-using CLayout = ck::tensor_layout::gemm::RowMajor;
-
-auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params)
-{
-    auto f_host_tensor_descriptor =
-        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
-            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
-            {
-                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
-                                            std::vector<std::size_t>({stride, 1}));
-            }
-            else
-            {
-                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
-                                            std::vector<std::size_t>({1, stride}));
-            }
-        };
-
-    Tensor<ADataType> a_m_k(
-        f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
-    Tensor<BDataType> b_k_n(
-        f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
-    Tensor<CDataType> c_m_n_host_result(
-        f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
-    Tensor<CDataType> c_m_n_device_result(
-        f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
-
-    a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
-    b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
-
-    return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result);
-}
-
-bool TestGemm(DeviceGemmPtr_& gemmPtr)
-{
-    // Arrange
-    ck::gemm_util::GemmParams params;
-    params.M       = 1024;
-    params.N       = 1024;
-    params.K       = 1024;
-    params.StrideA = 1024;
-    params.StrideB = 1024;
-    params.StrideC = 1024;
-
-    auto host_tensors = PrepareGemmTensor(params);
-    const Tensor<ADataType>& a  = std::get<0>(host_tensors);
-    const Tensor<BDataType>& b  = std::get<1>(host_tensors);
-    Tensor<CDataType>& c_host   = std::get<2>(host_tensors);
-    Tensor<CDataType>& c_device = std::get<3>(host_tensors);
-
-    auto a_element_op = PassThrough{};
-    auto b_element_op = PassThrough{};
-    auto c_element_op = PassThrough{};
-
-    using ReferenceGemmInstance = ck::tensor_operation::host::
-        ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
-    ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
-        a, b, c_host, a_element_op, b_element_op, c_element_op);
-
-    // Act
-    ck::gemm_util::RunDeviceGEMM(
-        gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op);
-
-    // Assert
-    bool res = test_util::check_err(c_device.mData, c_host.mData, "Error: incorrect results!");
-    std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
-
-    return res;
-}
-
-} // anonymous namespace
-
 int main()
 {
-    std::vector<DeviceGemmPtr_> gemmPtrs;
+    using ADataType = int8_t;
+    using BDataType = int8_t;
+    using CDataType = int8_t;
+
+    using RowMajor    = ck::tensor_layout::gemm::RowMajor;
+    using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
+
+    std::vector<DeviceGemmNoOpPtr> gemmPtrs;
+    bool res = true;
+
     ck::tensor_operation::device::device_gemm_instance::
-        add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemmPtrs);
+        add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemmPtrs);
 
-    bool res = true;
     for(auto& gemmPtr : gemmPtrs)
     {
-        res &= TestGemm(gemmPtr);
+        res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
+                                       ADataType,
+                                       BDataType,
+                                       CDataType,
+                                       ColumnMajor,
+                                       RowMajor,
+                                       RowMajor,
+                                       PassThrough,
+                                       PassThrough,
+                                       PassThrough>{}(gemmPtr);
     }
+
+    gemmPtrs.clear();
+    ck::tensor_operation::device::device_gemm_instance::
+        add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemmPtrs);
+
+    for(auto& gemmPtr : gemmPtrs)
+    {
+        res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
+                                       ADataType,
+                                       BDataType,
+                                       CDataType,
+                                       ColumnMajor,
+                                       ColumnMajor,
+                                       RowMajor,
+                                       PassThrough,
+                                       PassThrough,
+                                       PassThrough>{}(gemmPtr);
+    }
+
+    gemmPtrs.clear();
+    ck::tensor_operation::device::device_gemm_instance::
+        add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemmPtrs);
+
+    for(auto& gemmPtr : gemmPtrs)
+    {
+        res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
+                                       ADataType,
+                                       BDataType,
+                                       CDataType,
+                                       RowMajor,
+                                       RowMajor,
+                                       RowMajor,
+                                       PassThrough,
+                                       PassThrough,
+                                       PassThrough>{}(gemmPtr);
+    }
+
+    gemmPtrs.clear();
+    ck::tensor_operation::device::device_gemm_instance::
+        add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemmPtrs);
+
+    for(auto& gemmPtr : gemmPtrs)
+    {
+        res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
+                                       ADataType,
+                                       BDataType,
+                                       CDataType,
+                                       RowMajor,
+                                       ColumnMajor,
+                                       RowMajor,
+                                       PassThrough,
+                                       PassThrough,
+                                       PassThrough>{}(gemmPtr);
+    }
+
     std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
......
@@ -4,6 +4,10 @@
 #include "config.hpp"
 #include "device.hpp"
 #include "host_tensor.hpp"
+#include "host_tensor_generator.hpp"
+#include "reference_gemm.hpp"
+#include "tensor_layout.hpp"
+#include "test_util.hpp"
 
 namespace ck {
 namespace gemm_util {
@@ -98,6 +102,243 @@ void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
     c_m_n_device_buf.FromDevice(C.mData.data());
 }
 
+template <typename DeviceGemmPtr_,
+          typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+struct TestGemm
+{
+    auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params)
+    {
+        auto f_host_tensor_descriptor =
+            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+                if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
+                {
+                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                                std::vector<std::size_t>({stride, 1}));
+                }
+                else
+                {
+                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                                std::vector<std::size_t>({1, stride}));
+                }
+            };
+
+        Tensor<ADataType> a_m_k(
+            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
+        Tensor<BDataType> b_k_n(
+            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
+        Tensor<CDataType> c_m_n_host_result(
+            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
+        Tensor<CDataType> c_m_n_device_result(
+            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
+
+        auto f_generate_tensor_value = [](auto desc, auto type) {
+            using dataType = decltype(type);
+
+            if(std::is_same<dataType, int8_t>::value)
+            {
+                desc.GenerateTensorValue(GeneratorTensor_2<int8_t>{-5, 5});
+            }
+            else
+            {
+                desc.GenerateTensorValue(GeneratorTensor_3<dataType>{-0.5, 0.5});
+            }
+        };
+
+        f_generate_tensor_value(a_m_k, ADataType{});
+        f_generate_tensor_value(b_k_n, BDataType{});
+
+        return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result);
+    }
+
+    auto operator()(DeviceGemmPtr_& gemmPtr)
+    {
+        std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
+                  << ", CLayout = " << CLayout{}.name << std::endl;
+        std::cout << gemmPtr->GetTypeString() << std::endl;
+
+        // Arrange
+        ck::gemm_util::GemmParams params;
+        params.M       = 1024;
+        params.N       = 1024;
+        params.K       = 1024;
+        params.StrideA = 1024;
+        params.StrideB = 1024;
+        params.StrideC = 1024;
+
+        auto host_tensors = PrepareGemmTensor(params);
+
+        const Tensor<ADataType>& a  = std::get<0>(host_tensors);
+        const Tensor<BDataType>& b  = std::get<1>(host_tensors);
+        Tensor<CDataType>& c_host   = std::get<2>(host_tensors);
+        Tensor<CDataType>& c_device = std::get<3>(host_tensors);
+
+        auto a_element_op = AElementwiseOperation{};
+        auto b_element_op = BElementwiseOperation{};
+        auto c_element_op = CElementwiseOperation{};
+
+        using ReferenceGemmInstance =
+            ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                      BDataType,
+                                                      CDataType,
+                                                      AElementwiseOperation,
+                                                      BElementwiseOperation,
+                                                      CElementwiseOperation>;
+        ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
+            a, b, c_host, a_element_op, b_element_op, c_element_op);
+
+        // Act
+        ck::gemm_util::RunDeviceGEMM(
+            gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op);
+
+        // Assert
+        bool res = false;
+        if(std::is_same<CDataType, float>::value)
+        {
+            res = test_util::check_err(c_device.mData, c_host.mData, "Error: incorrect results!");
+            std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
+        }
+        else if(std::is_same<CDataType, ck::half_t>::value)
+        {
+            res = test_util::check_err(c_device.mData, c_host.mData, "Error: incorrect results!");
+            std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
+        }
+        else if(std::is_same<CDataType, int8_t>::value)
+        {
+            res = test_util::check_err(c_device.mData, c_host.mData, "Error: incorrect results!");
+            std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
+        }
+
+        return res;
+    }
+};
+
+template <typename DeviceGemmPtr_,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+struct TestGemmBF16
+{
+    using BF16 = ck::bhalf_t;
+
+    auto PrepareGemmTensorBF16(const ck::gemm_util::GemmParams& params)
+    {
+        auto f_host_tensor_descriptor =
+            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+                if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
+                {
+                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                                std::vector<std::size_t>({stride, 1}));
+                }
+                else
+                {
+                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
+                                                std::vector<std::size_t>({1, stride}));
+                }
+            };
+
+        // use fp32 host kernel to verify bf16 device kernel
+        Tensor<BF16> a_m_k_bf16(
+            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
+        Tensor<BF16> b_k_n_bf16(
+            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
+        Tensor<BF16> c_m_n_device_bf16(
+            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
+
+        Tensor<float> a_m_k_fp32(
+            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
+        Tensor<float> b_k_n_fp32(
+            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
+        Tensor<float> c_m_n_host_fp32(
+            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
+        Tensor<float> c_m_n_device_fp32(
+            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
+
+        a_m_k_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5});
+        b_k_n_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5});
+
+        bf16_to_f32_(a_m_k_bf16, a_m_k_fp32);
+        bf16_to_f32_(b_k_n_bf16, b_k_n_fp32);
+
+        return std::make_tuple(a_m_k_bf16,
+                               b_k_n_bf16,
+                               c_m_n_device_bf16,
+                               a_m_k_fp32,
+                               b_k_n_fp32,
+                               c_m_n_host_fp32,
+                               c_m_n_device_fp32);
+    }
+
+    auto operator()(DeviceGemmPtr_& gemmPtr)
+    {
+        // Arrange
+        ck::gemm_util::GemmParams params;
+        params.M       = 1024;
+        params.N       = 1024;
+        params.K       = 1024;
+        params.StrideA = 1024;
+        params.StrideB = 1024;
+        params.StrideC = 1024;
+
+        auto host_tensors = PrepareGemmTensorBF16(params);
+
+        const Tensor<BF16>& a_bf16   = std::get<0>(host_tensors);
+        const Tensor<BF16>& b_bf16   = std::get<1>(host_tensors);
+        Tensor<BF16>& c_device_bf16  = std::get<2>(host_tensors);
+        Tensor<float>& a_fp32        = std::get<3>(host_tensors);
+        Tensor<float>& b_fp32        = std::get<4>(host_tensors);
+        Tensor<float>& c_host_fp32   = std::get<5>(host_tensors);
+        Tensor<float>& c_device_fp32 = std::get<6>(host_tensors);
+
+        auto a_element_op = AElementwiseOperation{};
+        auto b_element_op = BElementwiseOperation{};
+        auto c_element_op = CElementwiseOperation{};
+
+        // use fp32 host kernel to verify bf16 device kernel
+        using ReferenceGemmInstance =
+            ck::tensor_operation::host::ReferenceGemm<float,
+                                                      float,
+                                                      float,
+                                                      AElementwiseOperation,
+                                                      BElementwiseOperation,
+                                                      CElementwiseOperation>;
+        ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
+            a_fp32, b_fp32, c_host_fp32, a_element_op, b_element_op, c_element_op);
+
+        // Act
+        ck::gemm_util::RunDeviceGEMM(gemmPtr,
+                                     params,
+                                     a_bf16,
+                                     b_bf16,
+                                     c_device_bf16,
+                                     a_element_op,
+                                     b_element_op,
+                                     c_element_op);
+
+        bf16_to_f32_(c_device_bf16, c_device_fp32);
+
+        // Assert
+        bool res = test_util::check_err(
+            c_device_fp32.mData, c_host_fp32.mData, "Error: incorrect results!", 1e-2f, 1e-3f);
+        std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
+
+        return res;
+    };
+};
+
 } // namespace gemm_util
 } // namespace ck
 
 #endif
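TestGemmBF16 checks the bf16 device kernel against an fp32 host reference because there is no bf16 host GEMM; bf16_to_f32_ (defined elsewhere in the library, not in this diff) widens tensors element-wise. A minimal sketch of one element's conversion, assuming bhalf_t holds the upper 16 bits of an IEEE-754 binary32 value (the usual bfloat16 representation):

#include <cstdint>
#include <cstring>

// Hedged sketch, not the library's implementation: bfloat16 is a truncated
// binary32, so widening shifts the 16 stored bits into the high half and
// zero-fills the discarded mantissa bits.
float bf16_to_f32(uint16_t x)
{
    const uint32_t bits = static_cast<uint32_t>(x) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f)); // bit-exact reinterpretation without UB
    return f;
}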
@@ -54,6 +54,49 @@ check_err(const std::vector<T>& out,
     return res;
 }
 
+bool check_err(const std::vector<_Float16>& out,
+               const std::vector<_Float16>& ref,
+               const std::string& msg,
+               _Float16 rtol = static_cast<_Float16>(1e-3f),
+               _Float16 atol = static_cast<_Float16>(1e-3f))
+{
+    if(out.size() != ref.size())
+    {
+        std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+                  << std::endl
+                  << msg << std::endl;
+        return false;
+    }
+
+    bool res{true};
+    int err_count  = 0;
+    double err     = 0;
+    double max_err = std::numeric_limits<_Float16>::min();
+    for(std::size_t i = 0; i < ref.size(); ++i)
+    {
+        double out_ = double(out[i]);
+        double ref_ = double(ref[i]);
+        err         = std::abs(out_ - ref_);
+        if(err > atol + rtol * std::abs(ref_) || !std::isfinite(out_) || !std::isfinite(ref_))
+        {
+            max_err = err > max_err ? err : max_err;
+            err_count++;
+            if(err_count < 5)
+            {
+                std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref["
+                          << i << "]: " << out_ << "!=" << ref_ << std::endl
+                          << msg << std::endl;
+            }
+            res = false;
+        }
+    }
+    if(!res)
+    {
+        std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+    }
+    return res;
+}
+
 template <typename T>
 typename std::enable_if<std::is_integral<T>::value, bool>::type check_err(
     const std::vector<T>& out, const std::vector<T>& ref, const std::string& msg, T = 0, T = 0)
......
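Like the existing check_err templates, the new overload accepts an element as wrong when |out[i] - ref[i]| > atol + rtol * |ref[i]|, or when either value is non-finite. A hypothetical call site (assuming the test_util header is included and the compiler supports the _Float16 extension):

#include <vector>

// Hypothetical usage: identical vectors pass with the default fp16
// tolerances rtol = atol = 1e-3.
int main()
{
    std::vector<_Float16> out(8, static_cast<_Float16>(0.5f));
    std::vector<_Float16> ref(8, static_cast<_Float16>(0.5f));
    return check_err(out, ref, "fp16 gemm mismatch") ? 0 : 1;
}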