Commit ebfa3921 authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/fix_test' into add_mfma_f64

parents 58f4d821 579e8e76
...@@ -106,7 +106,6 @@ template <typename DeviceGemmPtr_, ...@@ -106,7 +106,6 @@ template <typename DeviceGemmPtr_,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename CDataType,
typename AccDataType,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
typename CLayout, typename CLayout,
...@@ -140,17 +139,10 @@ struct TestGemm ...@@ -140,17 +139,10 @@ struct TestGemm
Tensor<CDataType> c_m_n_device_result( Tensor<CDataType> c_m_n_device_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
auto f_generate_tensor_value = [](auto& desc, auto type) { auto f_generate_tensor_value = [](auto& tensor, auto type) {
using dataType = decltype(type); using dataType = decltype(type);
if(std::is_same<dataType, int8_t>::value || std::is_same<dataType, double>::value) tensor.GenerateTensorValue(GeneratorTensor_2<dataType>{-5, 5});
{
desc.GenerateTensorValue(GeneratorTensor_2<dataType>{-5, 5});
}
else
{
desc.GenerateTensorValue(GeneratorTensor_3<dataType>{-0.5, 0.5});
}
}; };
f_generate_tensor_value(a_m_k, ADataType{}); f_generate_tensor_value(a_m_k, ADataType{});
...@@ -161,7 +153,6 @@ struct TestGemm ...@@ -161,7 +153,6 @@ struct TestGemm
auto operator()(DeviceGemmPtr_& gemmPtr) auto operator()(DeviceGemmPtr_& gemmPtr)
{ {
std::cout << "data type: " << typeid(ADataType{}).name() << std::endl;
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
<< ", CLayout = " << CLayout{}.name << std::endl; << ", CLayout = " << CLayout{}.name << std::endl;
std::cout << gemmPtr->GetTypeString() << std::endl; std::cout << gemmPtr->GetTypeString() << std::endl;
...@@ -190,7 +181,6 @@ struct TestGemm ...@@ -190,7 +181,6 @@ struct TestGemm
ck::tensor_operation::host::ReferenceGemm<ADataType, ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType, BDataType,
CDataType, CDataType,
AccDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation>; CElementwiseOperation>;
...@@ -203,12 +193,7 @@ struct TestGemm ...@@ -203,12 +193,7 @@ struct TestGemm
// Assert // Assert
bool res = false; bool res = false;
if(std::is_same<CDataType, double>::value) if(std::is_same<CDataType, float>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, float>::value)
{ {
res = ck::utils::check_err(c_device.mData, c_host.mData); res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
...@@ -314,7 +299,6 @@ struct TestGemmBF16 ...@@ -314,7 +299,6 @@ struct TestGemmBF16
// use fp32 host kernel to verify bf16 device kernel // use fp32 host kernel to verify bf16 device kernel
using ReferenceGemmInstance = using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<float, ck::tensor_operation::host::ReferenceGemm<float,
float,
float, float,
float, float,
AElementwiseOperation, AElementwiseOperation,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment