Commit b89a88b5 authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge branch 'develop' into wavelet_model

parents 41d5fca7 43c898f6
......@@ -254,23 +254,28 @@ struct Tensor
Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
template <typename OutT>
Tensor<OutT> CopyAsType()
Tensor<OutT> CopyAsType() const
{
Tensor<OutT> ret(mDesc);
for(size_t i = 0; i < mData.size(); i++)
{
ret.mData[i] = static_cast<OutT>(mData[i]);
ret.mData[i] = ck::type_convert<OutT>(mData[i]);
}
return ret;
}
Tensor(const Tensor& other) : mDesc(other.mDesc), mData(other.mData) {}
Tensor() = delete;
Tensor(const Tensor&) = default;
Tensor(Tensor&&) = default;
Tensor& operator=(const Tensor& other)
~Tensor() = default;
Tensor& operator=(const Tensor&) = default;
Tensor& operator=(Tensor&&) = default;
template <typename FromT>
explicit Tensor(const Tensor<FromT>& other) : Tensor(other.template CopyAsType<T>())
{
mDesc = other.mDesc;
mData = other.mData;
return *this;
}
const std::vector<std::size_t>& GetLengths() const { return mDesc.GetLengths(); }
......
......@@ -5,6 +5,7 @@
#include <cmath>
#include <numeric>
#include <random>
#include "ck/ck.hpp"
......@@ -126,6 +127,23 @@ struct GeneratorTensor_3<ck::bhalf_t>
}
};
template <typename T>
struct GeneratorTensor_4
{
std::default_random_engine generator;
std::normal_distribution<float> distribution;
GeneratorTensor_4(float mean, float stddev) : generator(1), distribution(mean, stddev){};
template <typename... Is>
T operator()(Is...)
{
float tmp = distribution(generator);
return ck::type_convert<T>(tmp);
}
};
struct GeneratorTensor_Checkboard
{
template <typename... Ts>
......
......@@ -16,6 +16,7 @@ add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)
add_subdirectory(batched_gemm_softmax_gemm)
add_subdirectory(batched_gemm_add_relu_gemm_add)
add_subdirectory(grouped_gemm)
add_subdirectory(contraction_scale)
add_subdirectory(contraction_bilinear)
......@@ -42,6 +43,7 @@ add_library(device_operations STATIC
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
$<TARGET_OBJECTS:device_gemm_bias_add_reduce_instance>
$<TARGET_OBJECTS:device_batched_gemm_instance>
$<TARGET_OBJECTS:device_batched_gemm_add_relu_gemm_add_instance>
$<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
$<TARGET_OBJECTS:device_grouped_gemm_instance>
$<TARGET_OBJECTS:device_contraction_scale_instance>
......@@ -78,6 +80,7 @@ target_include_directories(device_operations PUBLIC
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/problem_transform>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/device>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/device/impl>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/grid>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/block>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/tensor_operation/gpu/warp>
......
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
std::tuple<
// clang-format off
//##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K|NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
// no padding
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
// Padded fallback kernel
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
ck::Tuple<Row>,
Row,
ck::Tuple<Row>,
Row,
F16,
F16,
ck::Tuple<F16>,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances =
std::tuple<
// clang-format off
//##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K| NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1| A0K1| B0K1|B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
// no padding
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 256, 128, 32, 128, 32, 8, 8, 4, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>,
// Padded fallback kernel
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 4, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
ck::Tuple<Row>,
Col,
ck::Tuple<Row>,
Row,
F16,
F16,
ck::Tuple<F16>,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
add_instance_library(device_batched_gemm_gemm_instance
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)
......@@ -26,6 +26,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple<
......@@ -34,10 +35,22 @@ using device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_inst
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, // failed validation on MI100
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
// Padded fallback kernel
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
// clang-format on
>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances = std::tuple<
// clang-format off
//################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 4, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, // TODO: to enable; can trigger compiler crash in mainline #9110 but not in #10738
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 4, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 4, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, // TODO: to enable; can cause validation error on MI100
// DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 4, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>, // TODO: to enable; can cause validation error on MI100
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>,
// Padded fallback kernel
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Col, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 4, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
Col,
Col,
Row,
F16,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -26,6 +26,8 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded =
ck::tensor_operation::device::GemmSpecialization::MNOPadding; // Padding K is currently flawed
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
......@@ -35,10 +37,21 @@ using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_
//#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
// Padded fallback kernel
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
// clang-format on
>;
......
......@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
......@@ -27,19 +27,17 @@ using outputType = F16;
using Normalize = ck::tensor_operation::element_wise::Normalize;
using device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances = std::tuple<
// clang-format off
//###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector|
//###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector|
//###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector|
//###################|in | mean| square_mean| gamma| beta| out| ComputeDataType| functor| NDim| MPerThread| in, mean, square_mean, gamma, beta, out ScalarPerVector|
Device5AryElementwise<F16, F32, F32, F16, F16, F16, F32, Normalize, 2, 8, 8, 1, 1, 8, 8, 8 >,
Device5AryElementwise<F16, F32, F32, F16, F16, F16, F32, Normalize, 2, 4, 4, 1, 1, 4, 4, 4 >,
Device5AryElementwise<F16, F32, F32, F16, F16, F16, F32, Normalize, 2, 2, 2, 1, 1, 2, 2, 2 >,
Device5AryElementwise<F16, F32, F32, F16, F16, F16, F32, Normalize, 2, 1, 1, 1, 1, 1, 1, 1 >
//###################|<in, mean, square_mean, gamma, beta>| <out>| functor| NDim| MPerThread| <in, mean, square_mean, gamma, beta ScalarPerVector>| <out ScalarPerVector>|
DeviceElementwise<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2, 8, Sequence<8, 1, 1, 8, 8>, Sequence<8> >,
DeviceElementwise<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2, 4, Sequence<4, 1, 1, 4, 4>, Sequence<4> >,
DeviceElementwise<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2, 2, Sequence<2, 1, 1, 2, 2>, Sequence<2> >,
DeviceElementwise<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2, 1, Sequence<1, 1, 1, 1, 1>, Sequence<1> >
// clang-format on
>;
void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(
std::vector<DeviceElementwisePtr<5, 1, 2, Normalize>>& instances)
std::vector<DeviceElementwiseBasePtr<Tuple<F16, F32, F32, F16, F16>, Tuple<F16>, Normalize, 2>>&
instances)
{
add_device_operation_instances(
instances, device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances{});
......
......@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
......@@ -31,18 +31,18 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances = std::tuple<
// clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 8>, 4>
// clang-format on
>;
......
......@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
......@@ -31,18 +31,18 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances = std::tuple<
// clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>
// clang-format on
>;
......
......@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
......@@ -26,28 +26,23 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances = std::tuple<
// clang-format off
//###################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM|Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//###################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//###################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 96, 128, 4, 8, 16, 16, 3, 4, S<1, 4, 32, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 32, 256, 4, 4, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 16, 256, 4, 4, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 16, 128, 4, 4, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 1, 1, S<1, 32, 1, 8>, 4>
// clang-format on
>;
......
......@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
......@@ -31,23 +31,23 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances = std::tuple<
// clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>
//#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 4, 4, true, 1, 1, S<1, 16, 1, 4>, 4>
// clang-format on
>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/utility/data_type.hpp"
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
namespace {
using F16 = ck::half_t;
using F32 = float;
using Pass = ck::tensor_operation::element_wise::PassThrough;
} // namespace
template <index_t Rank, index_t Reduce>
using device_softmax_f16_f16_instances = std::tuple<
// clang-format off
// InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize>
DeviceSoftmax<F16, F32, F16, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1>, // fallback kernel
DeviceSoftmax<F16, F32, F16, Rank, Reduce, 256, 8, 32, 1, 8, 1, 8, 8>,
DeviceSoftmax<F16, F32, F16, Rank, Reduce, 256, 4, 64, 1, 8, 1, 8, 8>,
DeviceSoftmax<F16, F32, F16, Rank, Reduce, 256, 2, 128, 1, 8, 1, 8, 8>,
DeviceSoftmax<F16, F32, F16, Rank, Reduce, 256, 2, 128, 1, 16, 1, 8, 8>,
DeviceSoftmax<F16, F32, F16, Rank, Reduce, 256, 2, 128, 1, 32, 1, 8, 8>,
DeviceSoftmax<F16, F32, F16, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 8>,
DeviceSoftmax<F16, F32, F16, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 8>,
DeviceSoftmax<F16, F32, F16, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 8>
// InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize>
DeviceSoftmaxImpl< F16, F32, F16, Pass, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1>, // fallback kernel
DeviceSoftmaxImpl< F16, F32, F16, Pass, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 8, 8>,
DeviceSoftmaxImpl< F16, F32, F16, Pass, Pass, Rank, Reduce, 256, 4, 64, 1, 8, 1, 8, 8>,
DeviceSoftmaxImpl< F16, F32, F16, Pass, Pass, Rank, Reduce, 256, 2, 128, 1, 8, 1, 8, 8>,
DeviceSoftmaxImpl< F16, F32, F16, Pass, Pass, Rank, Reduce, 256, 2, 128, 1, 16, 1, 8, 8>,
DeviceSoftmaxImpl< F16, F32, F16, Pass, Pass, Rank, Reduce, 256, 2, 128, 1, 32, 1, 8, 8>,
DeviceSoftmaxImpl< F16, F32, F16, Pass, Pass, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 8>,
DeviceSoftmaxImpl< F16, F32, F16, Pass, Pass, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 8>,
DeviceSoftmaxImpl< F16, F32, F16, Pass, Pass, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 8>
// clang-format on
>;
void add_device_softmax_f16_f16_rank3_instances(std::vector<DeviceNormalizationPtr>& instances)
void add_device_softmax_f16_f16_rank3_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, Pass, Pass, 3>>& instances)
{
add_device_operation_instances(instances, device_softmax_f16_f16_instances<3, 1>{});
add_device_operation_instances(instances, device_softmax_f16_f16_instances<3, 2>{});
}
void add_device_softmax_f16_f16_rank4_instances(std::vector<DeviceNormalizationPtr>& instances)
void add_device_softmax_f16_f16_rank4_instances(
std::vector<DeviceSoftmaxPtr<F16, F32, F16, Pass, Pass, 4>>& instances)
{
add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 1>{});
add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 2>{});
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F32 = float;
namespace {
using F32 = float;
using Pass = ck::tensor_operation::element_wise::PassThrough;
} // namespace
template <index_t Rank, index_t Reduce>
using device_softmax_f32_f32_instances = std::tuple<
// clang-format off
// InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize>
DeviceSoftmax<F32, F32, F32, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1>, // fallback kernel
DeviceSoftmax<F32, F32, F32, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 4>,
DeviceSoftmax<F32, F32, F32, Rank, Reduce, 256, 4, 64, 1, 8, 1, 4, 4>,
DeviceSoftmax<F32, F32, F32, Rank, Reduce, 256, 2, 128, 1, 8, 1, 4, 4>,
DeviceSoftmax<F32, F32, F32, Rank, Reduce, 256, 2, 128, 1, 16, 1, 4, 4>,
DeviceSoftmax<F32, F32, F32, Rank, Reduce, 256, 2, 128, 1, 32, 1, 4, 4>,
DeviceSoftmax<F32, F32, F32, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 4>,
DeviceSoftmax<F32, F32, F32, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 4>,
DeviceSoftmax<F32, F32, F32, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 4>
// InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize>
DeviceSoftmaxImpl< F32, F32, F32, Pass, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1>, // fallback kernel
DeviceSoftmaxImpl< F32, F32, F32, Pass, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 4>,
DeviceSoftmaxImpl< F32, F32, F32, Pass, Pass, Rank, Reduce, 256, 4, 64, 1, 8, 1, 4, 4>,
DeviceSoftmaxImpl< F32, F32, F32, Pass, Pass, Rank, Reduce, 256, 2, 128, 1, 8, 1, 4, 4>,
DeviceSoftmaxImpl< F32, F32, F32, Pass, Pass, Rank, Reduce, 256, 2, 128, 1, 16, 1, 4, 4>,
DeviceSoftmaxImpl< F32, F32, F32, Pass, Pass, Rank, Reduce, 256, 2, 128, 1, 32, 1, 4, 4>,
DeviceSoftmaxImpl< F32, F32, F32, Pass, Pass, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 4>,
DeviceSoftmaxImpl< F32, F32, F32, Pass, Pass, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 4>,
DeviceSoftmaxImpl< F32, F32, F32, Pass, Pass, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 4>
// clang-format on
>;
void add_device_softmax_f32_f32_rank3_instances(std::vector<DeviceNormalizationPtr>& instances)
void add_device_softmax_f32_f32_rank3_instances(
std::vector<DeviceSoftmaxPtr<F32, F32, F32, Pass, Pass, 3>>& instances)
{
add_device_operation_instances(instances, device_softmax_f32_f32_instances<3, 1>{});
add_device_operation_instances(instances, device_softmax_f32_f32_instances<3, 2>{});
}
void add_device_softmax_f32_f32_rank4_instances(std::vector<DeviceNormalizationPtr>& instances)
void add_device_softmax_f32_f32_rank4_instances(
std::vector<DeviceSoftmaxPtr<F32, F32, F32, Pass, Pass, 4>>& instances)
{
add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 1>{});
add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 2>{});
......
......@@ -12,6 +12,8 @@ set(PROFILER_SOURCE
src/profile_gemm_add_add_fastgelu.cpp
src/profile_gemm_reduce.cpp
src/profile_batched_gemm.cpp
src/profile_batched_gemm_gemm.cpp
src/profile_batched_gemm_add_relu_gemm_add.cpp
src/profile_batched_gemm_reduce.cpp
src/profile_grouped_gemm.cpp
src/profile_conv_fwd.cpp
......@@ -35,6 +37,8 @@ target_link_libraries(ckProfiler PRIVATE device_gemm_add_add_fastgelu_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_add_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
namespace ck {
namespace profiler {
template <typename A0Layout,
typename B0Layout,
typename D0sLayout,
typename B1Layout,
typename D1sLayout,
typename E1Layout,
typename A0DataType,
typename B0DataType,
typename D0sDataType,
typename B1DataType,
typename D1sDataType,
typename E1DataType>
bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int O,
int BatchCount = 1,
int StrideA0 = -1,
int StrideB0 = -1,
int StrideD0 = -1,
int StrideB1 = -1,
int StrideD1 = -1,
int StrideE1 = -1,
int BatchStrideA0 = -1,
int BatchStrideB0 = -1,
int BatchStrideD0 = -1,
int BatchStrideB1 = -1,
int BatchStrideD1 = -1,
int BatchStrideE1 = -1)
{
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
using PassThrough = tensor_operation::element_wise::PassThrough;
using A0ElementOp = PassThrough;
using B0ElementOp = PassThrough;
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using B1ElementOp = PassThrough;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
using D0DataType = remove_cvref_t<tuple_element_t<0, D0sDataType>>;
using D0Layout = remove_cvref_t<tuple_element_t<0, D0sLayout>>;
using D1DataType = remove_cvref_t<tuple_element_t<0, D1sDataType>>;
using D1Layout = remove_cvref_t<tuple_element_t<0, D1sLayout>>;
// for reference
using RefAcc0DataType = float;
using RefAcc1DataType = float;
bool pass = true;
const int DefaultStrideA0 = ck::is_same_v<A0Layout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? O : M;
const int DefaultStrideE1 = ck::is_same_v<E1Layout, Row> ? O : M;
StrideA0 = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0;
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
StrideD0 = (StrideD0 < 0) ? DefaultStrideD0 : StrideD0;
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
StrideD1 = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1;
StrideE1 = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1;
const int DefaultBatchStrideA0 = (ck::is_same_v<A0Layout, Col> ? K : M) * StrideA0;
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
const int DefaultBatchStrideD0 = (ck::is_same_v<D0Layout, Col> ? N : M) * StrideD0;
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
const int DefaultBatchStrideD1 = (ck::is_same_v<D1Layout, Col> ? O : M) * StrideD1;
const int DefaultBatchStrideE1 = (ck::is_same_v<E1Layout, Col> ? O : M) * StrideE1;
BatchStrideA0 = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0;
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
BatchStrideD0 = BatchStrideD0 < 0 ? DefaultBatchStrideD0 : BatchStrideD0;
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1;
BatchStrideE1 = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1;
auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if(std::is_same<decltype(layout), Row>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
std::vector<std::size_t>({batch_stride, 1, stride}));
}
};
// E_m_o = A_m_k * B0_k_n * B1_n_o
Tensor<A0DataType> a0_g_m_k(
f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{}));
Tensor<B0DataType> b0_g_k_n(
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<D0DataType> d0_g_m_n(
f_host_tensor_descriptor(BatchCount, M, N, StrideD0, BatchStrideD0, D0Layout{}));
Tensor<B1DataType> b1_g_n_o(
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<D1DataType> d1_g_m_o(
f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{}));
Tensor<E1DataType> e1_g_m_o_host_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
Tensor<E1DataType> e1_g_m_o_device_result(
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
// Host verification: Output of Gemm0 is input A of Gemm1
Tensor<RefAcc0DataType> c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<RefAcc0DataType> e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
Tensor<RefAcc1DataType> c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{}));
std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "d0_g_m_n: " << d0_g_m_n.mDesc << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
std::cout << "d1_g_m_o: " << d1_g_m_o.mDesc << std::endl;
std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_g_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 3});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 3});
d0_g_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 3});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 3});
d1_g_m_o.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 3});
break;
default:
a0_g_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_g_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
d1_g_m_o.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
}
DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize());
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
DeviceMem d0_g_m_n_device_buf(sizeof(D0DataType) * d0_g_m_n.mDesc.GetElementSpaceSize());
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize());
DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) *
e1_g_m_o_device_result.mDesc.GetElementSize());
a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data());
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
d0_g_m_n_device_buf.ToDevice(d0_g_m_n.mData.data());
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data());
auto a0_element_op = A0ElementOp{};
auto b0_element_op = B0ElementOp{};
auto cde0_element_op = CDE0ElementOp{};
auto b1_element_op = B1ElementOp{};
auto cde1_element_op = CDE1ElementOp{};
using DeviceOp =
tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD<A0Layout,
B0Layout,
D0sLayout,
B1Layout,
D1sLayout,
E1Layout,
A0DataType,
B0DataType,
D0sDataType,
B1DataType,
D1sDataType,
E1DataType,
A0ElementOp,
B0ElementOp,
CDE0ElementOp,
B1ElementOp,
CDE1ElementOp>;
// get device op instances
const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
if(do_verification)
{
// Ref Gemm0
using ReferenceGemm0Instance = tensor_operation::host::ReferenceBatchedGemm<A0DataType,
B0DataType,
RefAcc0DataType,
RefAcc0DataType,
A0ElementOp,
B0ElementOp,
PassThrough>;
// Ref Gemm1
using ReferenceGemm1Instance = tensor_operation::host::ReferenceBatchedGemm<RefAcc0DataType,
B1DataType,
RefAcc1DataType,
RefAcc1DataType,
PassThrough,
B1ElementOp,
PassThrough>;
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{});
ref_gemm0_invoker.Run(ref_gemm0_argument);
// cde0_elementwise
e0_g_m_n.ForEach(
[&](auto&, auto idx) { cde0_element_op(e0_g_m_n(idx), c0_g_m_n(idx), d0_g_m_n(idx)); });
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{});
ref_gemm1_invoker.Run(ref_gemm1_argument);
// cde1_elementwise
e1_g_m_o_host_result.ForEach([&](auto&, auto idx) {
cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx));
});
}
std::string best_op_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device op instances
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<A0DataType*>(a0_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
std::array<const void*, 1>{d0_g_m_n_device_buf.GetDeviceBuffer()},
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
std::array<const void*, 1>{d1_g_m_o_device_buf.GetDeviceBuffer()},
static_cast<E1DataType*>(e1_g_m_o_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
BatchCount,
StrideA0,
StrideB0,
std::array<ck::index_t, 1>{StrideD0},
StrideB1,
std::array<ck::index_t, 1>{StrideD1},
StrideE1,
BatchStrideA0,
BatchStrideB0,
std::array<ck::index_t, 1>{BatchStrideD0},
BatchStrideB1,
std::array<ck::index_t, 1>{BatchStrideD1},
BatchStrideE1,
a0_element_op,
b0_element_op,
cde0_element_op,
b1_element_op,
cde1_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
std::string op_name = op_ptr->GetTypeString();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype =
(sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(D0DataType) * N +
sizeof(B1DataType) * N * O + sizeof(E1DataType) * M * O + sizeof(D1DataType) * O) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());
pass = pass & ck::utils::check_err(e1_g_m_o_device_result.mData,
e1_g_m_o_host_result.mData);
if(do_log)
{
LogRangeAsType<float>(
std::cout << "e1_g_m_o_host_result : ", e1_g_m_o_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "e1_g_m_o_device_result : ", e1_g_m_o_device_result.mData, ",")
<< std::endl;
}
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck
......@@ -195,6 +195,12 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// early fail when no instances are found
if(op_ptrs.size() == 0)
{
return false;
}
if(do_verification)
{
auto ref_gemm0 = ReferenceGemm0Instance{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment