// SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #include #include #include #include #include #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "test/wmma_op/wmma_op_util.hpp" template bool run_test() { using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; using PassThrough = ck::tensor_operation::element_wise::PassThrough; bool pass = true; const auto matmul_default = ck::wmma_op_util::matmul; const auto matmul_swizzle_a = ck::wmma_op_util::matmul_swizzle_a; const auto wmma_kernel_container = std::make_tuple(matmul_default, matmul_swizzle_a); ck::static_for<0, 2, 1>{}([&](auto i) { pass &= ck::wmma_op_util::TestWmma{}>(wmma_kernel_container)), SrcType, SrcType, DstType, GPUAccType, CPUAccType, decltype(Row{}), decltype(Col{}), decltype(Row{}), PassThrough, PassThrough, PassThrough, AccNum>{}(std::get{}>(wmma_kernel_container)); }); return pass ? 1 : 0; } int main(int, char*[]) { bool pass = true; // clang-format off // |SrcType |DstType |GPUAccType |CPUAccType |AccNum pass &= run_test(); pass &= run_test(); pass &= run_test(); pass &= run_test(); pass &= run_test(); // clang-format on std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl; return pass ? 0 : 1; }