// SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" using ADataType = ck::half_t; using BDataType = int8_t; using CDataType = ck::half_t; using AccDataType = float; using ALayout = Row; using BLayout = Row; using CLayout = Row; using AElementOp = PassThrough; using BElementOp = PassThrough; using CElementOp = PassThrough; // static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; // clang-format off using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl // ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| // ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | Order| | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 512, 2, 4, 1, 8, 1, S<1, 1, 1, 4>, S<2, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<1, 2, 0, 3>, 2, 8, S<0, 1, 2, 3, 4, 5>, 5, 1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }