"src/turbomind/models/llama/LlamaContextDecoder.cc" did not exist on "720fc533da804ac3f46ee938864403e51fcd9fa7"
Commit 49facb91 authored by Harisankar Sadasivan's avatar Harisankar Sadasivan
Browse files

files for gemv and tall and skinny gemm examples and corresponding entries to ckprofiler

parent 98fd41f5
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Short names for the element data types referenced by the instance list below.
using F32 = float;
using F16 = ck::half_t;

// Layout tags taken from ck::tensor_layout::gemm.
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row = ck::tensor_layout::gemm::RowMajor;

// Element-wise operation type passed for the A, B and C operands of every instance.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Shorthand for a compile-time ck::Sequence of index values.
template <ck::index_t... Values>
using S = ck::Sequence<Values...>;

// GEMM specialization value shared by every instance in this file.
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
///< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, M1, B*N1, K0, K1, M1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1,M1, 1>,S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, 4, K1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
//M1 is always tied to 16
//N1=2
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 1, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 1, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 1, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 2, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 2, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 2, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 3, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 3, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 3, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 4, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 4, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 4, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 5, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 5, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 5, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 6, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 6, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 6, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 7, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 7, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 7, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 8, 2, 16, 2, 1, S<1,1, 1, 1, 2>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 8, 4, 16, 2, 1, S<1,1, 1, 1, 4>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 128, 8, 8, 16, 2, 1, S<1,1, 1, 1, 8>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 2>,
// //N1=4
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 1, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 1, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 1, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 2, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 2, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 2, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 3, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 3, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 3, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 4, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 4, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 4, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 5, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 5, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 5, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 6, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 6, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 6, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 7, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 7, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 7, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 8, 2, 16, 4, 1, S<1,1, 1, 1, 2>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 8, 4, 16, 4, 1, S<1,1, 1, 1, 4>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 256, 8, 8, 16, 4, 1, S<1,1, 1, 1, 8>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 4>,
// //N1=8
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 1, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 1, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 1, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,1, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 2, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 2, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 2, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,2, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 3, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 3, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 3, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,3, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 4, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 4, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
ck::tensor_operation::device::deviceTsmmDl
< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 4, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,4, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 5, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 5, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 5, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,5, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 6, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 6, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 6, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,6, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 7, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 7, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 7, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,7, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 8, 2, 16, 8, 1, S<1,1, 1, 1, 2>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 4, 2, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 8, 4, 16, 8, 1, S<1,1, 1, 1, 4>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, S<1,1, 1, 1, 4>, S<0,1,2,3,4>, 4, 4, S<0, 1, 2, 3, 4, 5>, 5, 8>,
// ck::tensor_operation::device::deviceTsmmDl
///< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 512, 8, 8, 16, 8, 1, S<1,1, 1, 1, 8>, S<1,8, 1,16, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, S<1,1, 1, 1, 8>, S<0,1,2,3,4>, 4, 8, S<0, 1, 2, 3, 4, 5>, 5, 8>
// clang-format on
>;
// Entry point for the F16/F16/F16 (A row-major, B column-major, C row-major) instance list.
//
// Forwards `instances` together with a default-constructed instance tuple to
// add_device_operation_instances. Presumably that helper appends one operation per
// tuple element to `instances` -- confirm against its definition.
//
// @param instances  caller-owned vector of DeviceTsmm pointers, passed by reference.
void add_device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<
        DeviceTsmm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
        instances)
{
    using InstanceTuple = device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_nk_mn_instances;

    add_device_operation_instances(instances, InstanceTuple{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include <iostream>
#include <typeinfo>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_tall_and_skinny_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemv_splitk.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace profiler {
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
bool profile_gemv_splitk_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideC,
int KBatch)
{
bool pass = true;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
using DeviceOp = ck::tensor_operation::device::DeviceTsmm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference GEMM
if(do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
}
std::string best_op_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
float best_kbatch = 0;
// profile device GEMM instances
for(auto& op_ptr : op_ptrs)
{
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 36, 40, 60,
64, 72, 80, 88, 96, 128, 144, 160, 176, 192, 256};
if(KBatch > 0)
{
kbatch_list = {KBatch};
}
for(std::size_t i = 0; i < kbatch_list.size(); i++)
{
auto kbatch_curr = kbatch_list[i];
auto argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
kbatch_curr);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl;
}
}
std::string op_name = op_ptr->GetTypeString();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
<< kbatch_curr << std::endl;
#if defined CK_ENABLE_FP8
// set softer tolerances for fp8
if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
is_same_v<CDataType, f8_t>)
{
std::string msg = "Error: Incorrect results!";
double rtol = 1e-1;
double atol = 1e-1;
pass = pass & ck::utils::check_err(
c_m_n_device_result, c_m_n_host_result, msg, rtol, atol);
}
else
{
#endif
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#if defined CK_ENABLE_FP8
}
#endif
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
}
}
if constexpr(is_same<CDataType, float>::value)
{
std::cout << "Best Perf for datatype = f32";
}
else if constexpr(is_same<CDataType, half_t>::value)
{
std::cout << "Best Perf for datatype = f16";
}
else if constexpr(is_same<CDataType, bhalf_t>::value)
{
std::cout << "Best Perf for datatype = bf16";
}
else if constexpr(is_same<CDataType, int8_t>::value)
{
std::cout << "Best Perf for datatype = int8";
}
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " ALayout = RowMajor";
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " ALayout = ColumnMajor";
}
if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " BLayout = RowMajor";
}
else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " BLayout = ColumnMajor";
}
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< " GB/s, " << best_op_name << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include <iostream>
#include <typeinfo>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_tall_and_skinny_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/tall_and_skinny_gemm_splitk.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace profiler {
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
bool profile_tall_and_skinny_gemm_splitk_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideC,
int KBatch)
{
bool pass = true;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
using DeviceOp = ck::tensor_operation::device::DeviceTsmm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference GEMM
if(do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
}
std::string best_op_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
float best_kbatch = 0;
// profile device GEMM instances
for(auto& op_ptr : op_ptrs)
{
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 36, 40, 60,
64, 72, 80, 88, 96, 128, 144, 160, 176, 192, 256};
if(KBatch > 0)
{
kbatch_list = {KBatch};
}
for(std::size_t i = 0; i < kbatch_list.size(); i++)
{
auto kbatch_curr = kbatch_list[i];
auto argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
kbatch_curr);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl;
}
}
std::string op_name = op_ptr->GetTypeString();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
<< kbatch_curr << std::endl;
#if defined CK_ENABLE_FP8
// set softer tolerances for fp8
if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
is_same_v<CDataType, f8_t>)
{
std::string msg = "Error: Incorrect results!";
double rtol = 1e-1;
double atol = 1e-1;
pass = pass & ck::utils::check_err(
c_m_n_device_result, c_m_n_host_result, msg, rtol, atol);
}
else
{
#endif
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#if defined CK_ENABLE_FP8
}
#endif
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
}
}
if constexpr(is_same<CDataType, float>::value)
{
std::cout << "Best Perf for datatype = f32";
}
else if constexpr(is_same<CDataType, half_t>::value)
{
std::cout << "Best Perf for datatype = f16";
}
else if constexpr(is_same<CDataType, bhalf_t>::value)
{
std::cout << "Best Perf for datatype = bf16";
}
else if constexpr(is_same<CDataType, int8_t>::value)
{
std::cout << "Best Perf for datatype = int8";
}
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " ALayout = RowMajor";
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " ALayout = ColumnMajor";
}
if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " BLayout = RowMajor";
}
else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " BLayout = ColumnMajor";
}
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< " GB/s, " << best_op_name << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_gemv_splitk_impl.hpp"
#include "profiler_operation_registry.hpp"
// Matrix layout combinations selectable through CLI arg3. The numeric value is what the
// user types on the command line:
//   0: A[m, k] * B[k, n] = C[m, n]
//   1: A[m, k] * B[n, k] = C[m, n]
//   2: A[k, m] * B[k, n] = C[m, n]
//   3: A[k, m] * B[n, k] = C[m, n]
enum class GemmMatrixLayout : int
{
    MK_KN_MN = 0,
    MK_NK_MN = 1,
    KM_KN_MN = 2,
    KM_NK_MN = 3,
};
// Data type combinations selectable through CLI arg2. The numeric value is what the user
// types on the command line (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8).
enum class GemmDataType : int
{
    F32_F32_F32    = 0,
    F16_F16_F16    = 1,
    BF16_BF16_BF16 = 2,
    INT8_INT8_INT8 = 3,
    F8_F16_F16     = 4,
    F16_F8_F16     = 5,
};
// Name and one-line description of this profiler operation. Both are spliced into the
// usage text printed by profile_gemv_splitk and handed to REGISTER_PROFILER_OPERATION.
// NOTE(review): the description reads "Split-K GEMM" although the operation is named
// gemv_splitk — confirm whether the text should mention GEMV.
#define OP_NAME "gemv_splitk"
#define OP_DESC "Split-K GEMM"
int profile_gemv_splitk(int argc, char* argv[])
{
if(argc != 15)
{
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: split k into mulitiple batch\n");
exit(1);
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const int M = std::stoi(argv[8]);
const int N = std::stoi(argv[9]);
const int K = std::stoi(argv[10]);
const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]);
const int KBatch = std::stoi(argv[14]);
using F32 = float;
using F16 = ck::half_t;
// #if defined CK_ENABLE_FP8
// using F8 = ck::f8_t;
// #endif
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
auto profile = [&](auto a_type,
auto b_type,
auto acc_type,
auto c_type,
auto a_layout,
auto b_layout,
auto c_layout) {
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
using AccDataType = decltype(acc_type);
using CDataType = decltype(c_type);
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using CLayout = decltype(c_layout);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
bool pass = ck::profiler::profile_gemv_splitk_impl<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideC < 0) ? DefaultStrideC : StrideC,
KBatch);
return pass ? 0 : 1;
};
// if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
// {
// return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
// }
// else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
// {
// return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
// }
// else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
// {
// return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
// }
// else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
// {
// return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
// }
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
}
// else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
// {
// return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{});
// }
// else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
// {
// return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
// }
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
}
// Passes this operation's name, description and entry point to the registration macro.
// NOTE(review): the macro presumably comes from profiler_operation_registry.hpp and makes
// the operation selectable by name from the profiler executable — confirm against that header.
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemv_splitk);
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_tall_and_skinny_gemm_splitk_impl.hpp"
#include "profiler_operation_registry.hpp"
// Matrix layout combinations selectable through CLI arg3. The numeric value is what the
// user types on the command line:
//   0: A[m, k] * B[k, n] = C[m, n]
//   1: A[m, k] * B[n, k] = C[m, n]
//   2: A[k, m] * B[k, n] = C[m, n]
//   3: A[k, m] * B[n, k] = C[m, n]
enum class GemmMatrixLayout : int
{
    MK_KN_MN = 0,
    MK_NK_MN = 1,
    KM_KN_MN = 2,
    KM_NK_MN = 3,
};
// Data type combinations selectable through CLI arg2. The numeric value is what the user
// types on the command line (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8).
enum class GemmDataType : int
{
    F32_F32_F32    = 0,
    F16_F16_F16    = 1,
    BF16_BF16_BF16 = 2,
    INT8_INT8_INT8 = 3,
    F8_F16_F16     = 4,
    F16_F8_F16     = 5,
};
// Name and one-line description of this profiler operation. Both are spliced into the
// usage text printed by profile_tall_and_skinny_gemm_splitk and handed to
// REGISTER_PROFILER_OPERATION.
#define OP_NAME "tall_and_skinny_gemm_splitk"
#define OP_DESC "Tall and Skinny GEMM splitk"
int profile_tall_and_skinny_gemm_splitk(int argc, char* argv[])
{
if(argc != 15)
{
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: split k into mulitiple batch\n");
exit(1);
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const int M = std::stoi(argv[8]);
const int N = std::stoi(argv[9]);
const int K = std::stoi(argv[10]);
const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]);
const int StrideC = std::stoi(argv[13]);
const int KBatch = std::stoi(argv[14]);
using F32 = float;
using F16 = ck::half_t;
// #if defined CK_ENABLE_FP8
// using F8 = ck::f8_t;
// #endif
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
auto profile = [&](auto a_type,
auto b_type,
auto acc_type,
auto c_type,
auto a_layout,
auto b_layout,
auto c_layout) {
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
using AccDataType = decltype(acc_type);
using CDataType = decltype(c_type);
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using CLayout = decltype(c_layout);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
bool pass = ck::profiler::profile_tall_and_skinny_gemm_splitk_impl<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideC < 0) ? DefaultStrideC : StrideC,
KBatch);
return pass ? 0 : 1;
};
// if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
// {
// return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{});
// }
// else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
// {
// return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{});
// }
// else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
// {
// return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{});
// }
// else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
// {
// return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{});
// }
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
}
// else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
// {
// return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{});
// }
// else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
// {
// return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{});
// }
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
}
// Passes this operation's name, description and entry point to the registration macro.
// NOTE(review): the macro presumably comes from profiler_operation_registry.hpp and makes
// the operation selectable by name from the profiler executable — confirm against that header.
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_tall_and_skinny_gemm_splitk);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment