Commit 9e8cb769 authored by qinletao's avatar qinletao
Browse files

add acc data type

parent ebfa3921
...@@ -106,6 +106,7 @@ template <typename DeviceGemmPtr_, ...@@ -106,6 +106,7 @@ template <typename DeviceGemmPtr_,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename CDataType,
typename AccDataType,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
typename CLayout, typename CLayout,
...@@ -153,6 +154,7 @@ struct TestGemm ...@@ -153,6 +154,7 @@ struct TestGemm
auto operator()(DeviceGemmPtr_& gemmPtr) auto operator()(DeviceGemmPtr_& gemmPtr)
{ {
std::cout << "Data type: " << typeid(CDataType{}).name() << std::endl;
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
<< ", CLayout = " << CLayout{}.name << std::endl; << ", CLayout = " << CLayout{}.name << std::endl;
std::cout << gemmPtr->GetTypeString() << std::endl; std::cout << gemmPtr->GetTypeString() << std::endl;
...@@ -181,6 +183,7 @@ struct TestGemm ...@@ -181,6 +183,7 @@ struct TestGemm
ck::tensor_operation::host::ReferenceGemm<ADataType, ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType, BDataType,
CDataType, CDataType,
AccDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation>; CElementwiseOperation>;
...@@ -193,21 +196,8 @@ struct TestGemm ...@@ -193,21 +196,8 @@ struct TestGemm
// Assert // Assert
bool res = false; bool res = false;
if(std::is_same<CDataType, float>::value) res = ck::utils::check_err(c_device.mData, c_host.mData);
{ std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, ck::half_t>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, int8_t>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
return res; return res;
} }
...@@ -299,6 +289,7 @@ struct TestGemmBF16 ...@@ -299,6 +289,7 @@ struct TestGemmBF16
// use fp32 host kernel to verify bf16 device kernel // use fp32 host kernel to verify bf16 device kernel
using ReferenceGemmInstance = using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<float, ck::tensor_operation::host::ReferenceGemm<float,
float,
float, float,
float, float,
AElementwiseOperation, AElementwiseOperation,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment