"vscode:/vscode.git/clone" did not exist on "462a79d39ad278090fbe5fc723d5a2c4d22185b9"
Commit ebe8b7d1 authored by Anthony Chang's avatar Anthony Chang
Browse files

simplify gemm test

parent 37d83d7d
...@@ -24,56 +24,11 @@ ...@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp" #include "test/gemm/gemm_util.hpp"
int main() using ADataType = ck::bhalf_t;
{ using BDataType = ck::bhalf_t;
using ADataType = ck::bhalf_t; using CDataType = ck::bhalf_t;
using BDataType = ck::bhalf_t; using AccDataType = float;
using CDataType = ck::bhalf_t;
using AccDataType = float;
using Row = ck::tensor_layout::gemm::RowMajor; #include "run_gemm_test.inc"
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; int main() { return run_gemm_test(); }
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool pass = true;
using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>;
const auto gemmPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& gemmPtr : gemmPtrs)
{
pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
ADataType,
BDataType,
CDataType,
AccDataType,
decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
return pass;
};
bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
return pass ? 0 : 1;
}
...@@ -24,56 +24,11 @@ ...@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp" #include "test/gemm/gemm_util.hpp"
int main() using ADataType = ck::half_t;
{ using BDataType = ck::half_t;
using ADataType = ck::half_t; using CDataType = ck::half_t;
using BDataType = ck::half_t; using AccDataType = float;
using CDataType = ck::half_t;
using AccDataType = float;
using Row = ck::tensor_layout::gemm::RowMajor; #include "run_gemm_test.inc"
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; int main() { return run_gemm_test(); }
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool pass = true;
using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>;
const auto gemmPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& gemmPtr : gemmPtrs)
{
pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
ADataType,
BDataType,
CDataType,
AccDataType,
decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
return pass;
};
bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
return pass ? 0 : 1;
}
...@@ -24,56 +24,11 @@ ...@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp" #include "test/gemm/gemm_util.hpp"
int main() using ADataType = float;
{ using BDataType = float;
using ADataType = float; using CDataType = float;
using BDataType = float; using AccDataType = float;
using CDataType = float;
using AccDataType = float;
using Row = ck::tensor_layout::gemm::RowMajor; #include "run_gemm_test.inc"
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; int main() { return run_gemm_test(); }
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool pass = true;
using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>;
const auto gemmPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& gemmPtr : gemmPtrs)
{
pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
ADataType,
BDataType,
CDataType,
AccDataType,
decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
return pass;
};
bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
return pass ? 0 : 1;
}
...@@ -24,56 +24,11 @@ ...@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp" #include "test/gemm/gemm_util.hpp"
int main() using ADataType = double;
{ using BDataType = double;
using ADataType = double; using CDataType = double;
using BDataType = double; using AccDataType = double;
using CDataType = double;
using AccDataType = double;
using Row = ck::tensor_layout::gemm::RowMajor; #include "run_gemm_test.inc"
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; int main() { return run_gemm_test(); }
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool pass = true;
using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>;
const auto gemmPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& gemmPtr : gemmPtrs)
{
pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
ADataType,
BDataType,
CDataType,
AccDataType,
decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
return pass;
};
bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
return pass ? 0 : 1;
}
...@@ -24,56 +24,11 @@ ...@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp" #include "test/gemm/gemm_util.hpp"
int main() using ADataType = int8_t;
{ using BDataType = int8_t;
using ADataType = int8_t; using CDataType = int8_t;
using BDataType = int8_t; using AccDataType = int32_t;
using CDataType = int8_t;
using AccDataType = int32_t;
using Row = ck::tensor_layout::gemm::RowMajor; #include "run_gemm_test.inc"
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; int main() { return run_gemm_test(); }
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool pass = true;
using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>;
const auto gemmPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& gemmPtr : gemmPtrs)
{
pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
ADataType,
BDataType,
CDataType,
AccDataType,
decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
return pass;
};
bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
int run_gemm_test()
{
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool pass = true;
using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>;
const auto gemmPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& gemmPtr : gemmPtrs)
{
pass &= ck::gemm_util::TestGemm<AccDataType>{}(gemmPtr.get());
}
return pass;
};
bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
return pass ? 0 : 1;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment