Commit ebfa3921 authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/fix_test' into add_mfma_f64

parents 58f4d821 579e8e76
......@@ -106,7 +106,6 @@ template <typename DeviceGemmPtr_,
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
......@@ -140,17 +139,10 @@ struct TestGemm
Tensor<CDataType> c_m_n_device_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
auto f_generate_tensor_value = [](auto& desc, auto type) {
auto f_generate_tensor_value = [](auto& tensor, auto type) {
using dataType = decltype(type);
if(std::is_same<dataType, int8_t>::value || std::is_same<dataType, double>::value)
{
desc.GenerateTensorValue(GeneratorTensor_2<dataType>{-5, 5});
}
else
{
desc.GenerateTensorValue(GeneratorTensor_3<dataType>{-0.5, 0.5});
}
tensor.GenerateTensorValue(GeneratorTensor_2<dataType>{-5, 5});
};
f_generate_tensor_value(a_m_k, ADataType{});
......@@ -161,7 +153,6 @@ struct TestGemm
auto operator()(DeviceGemmPtr_& gemmPtr)
{
std::cout << "data type: " << typeid(ADataType{}).name() << std::endl;
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
<< ", CLayout = " << CLayout{}.name << std::endl;
std::cout << gemmPtr->GetTypeString() << std::endl;
......@@ -190,7 +181,6 @@ struct TestGemm
ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
......@@ -203,12 +193,7 @@ struct TestGemm
// Assert
bool res = false;
if(std::is_same<CDataType, double>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, float>::value)
if(std::is_same<CDataType, float>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
......@@ -314,7 +299,6 @@ struct TestGemmBF16
// use fp32 host kernel to verify bf16 device kernel
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<float,
float,
float,
float,
AElementwiseOperation,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment