Commit 436fed88 authored by root's avatar root
Browse files

Merge branch 'add_contraction_example_fp64' of...

Merge branch 'add_contraction_example_fp64' of https://github.com/ROCmSoftwarePlatform/composable_kernel into add_contraction_example_fp64
parents b7403bf4 606fb7f6
......@@ -26,9 +26,9 @@ template <index_t NumDimG,
typename Acc1BiasDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename C0DEElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
typename C1DEElementwiseOperation,
MaskingSpecialization MaskingSpec>
struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
{
......@@ -58,9 +58,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
C0DEElementwiseOperation c0de_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) = 0;
C1DEElementwiseOperation c1de_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
......
......@@ -533,6 +533,11 @@ struct DeviceElementwiseNormalizationImpl
return (false);
}
if(p_arg_->x_lds_size_ >= 65536)
{
return (false);
}
return true;
};
......
......@@ -669,6 +669,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
{
throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
}
if(arg.p_workspace_e_grid_ == nullptr || arg.p_workspace_mean_ == nullptr ||
arg.p_workspace_var_ == nullptr || arg.p_workspace_count_ == nullptr)
throw std::runtime_error("wrong! WorkSpace pointer has not been set");
index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.gemm_e_grid_desc_m_n_);
......@@ -939,7 +942,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
}
}
return true;
return GridwiseGemmWelford::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.gemm_e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
// polymorphic
......@@ -1055,7 +1062,12 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
<< GemmKPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< getGemmSpecializationString(GemmSpec)
<< getGemmSpecializationString(GemmSpec) << ", "
<< PostShuffleThreadClusterSize_M_N::At(I0) << ", "
<< PostShuffleThreadClusterSize_M_N::At(I1) << ", "
<< LayernormThreadClusterSize_M_N::At(I0) << ", "
<< LayernormThreadClusterSize_M_N::At(I1) << ", "
<< LayernormThreadSliceSize_M
<< ">"
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
......
......@@ -49,6 +49,14 @@ struct Add
y = x0 + x1;
};
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x1);
y = x0 + x1_tmp;
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
......@@ -67,6 +75,30 @@ struct Add
};
};
struct ScaleAdd
{
__host__ __device__ ScaleAdd(float scale) : scale_(scale) {}
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
template <>
__host__ __device__ void
operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
{
y = scale_ * x0 + ck::type_convert<float>(x1);
};
template <>
__host__ __device__ void
operator()<float, float, bhalf_t>(float& y, const float& x0, const bhalf_t& x1) const
{
y = scale_ * x0 + ck::type_convert<float>(x1);
};
float scale_;
};
struct Subtract
{
template <typename T>
......
......@@ -158,9 +158,9 @@ static inline __device__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00;
};
static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
static inline __device__ double sqrt(double x) { return ::sqrt(x); };
static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
} // namespace math
} // namespace ck
......@@ -89,8 +89,10 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances);
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
template <typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename Acc0BiasDataType,
MaskingSpecialization MaskingSpec>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpec>>
{
using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpec>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t> &&
Acc0BiasDataType::Size() == 1 &&
is_same_v<tuple_element_t<0, Acc0BiasDataType>, half_t>)
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16> &&
Acc0BiasDataType::Size() == 1 &&
is_same_v<tuple_element_t<0, Acc0BiasDataType>, BF16>)
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -10,7 +10,6 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
F16,
F16,
PassThrough,
PassThrough,
AddReluAdd,
PassThrough>>>&);
void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Row,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
F16,
F16,
PassThrough,
PassThrough,
AddReluAdd,
PassThrough>>>&);
void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Col,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
F16,
F16,
PassThrough,
PassThrough,
AddReluAdd,
PassThrough>>>&);
void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDLayernorm<Col,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
F16,
F16,
PassThrough,
PassThrough,
AddReluAdd,
PassThrough>>>&);
// GEMM + Add + Relu + Add + Layernorm
template <typename ALayout,
typename BLayout,
typename D0Layout,
typename D1Layout,
typename HLayout,
typename ADataType,
typename BDataType,
typename D0DataType,
typename D1DataType,
typename GammaDataType,
typename BetaDataType,
typename HDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleDLayernorm<
ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
HLayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
GammaDataType,
BetaDataType,
HDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddReluAdd,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceGemmMultipleDLayernorm<ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
HLayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
GammaDataType,
BetaDataType,
HDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddReluAdd,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<D0DataType, half_t> && is_same_v<D1DataType, half_t> &&
is_same_v<GammaDataType, half_t> && is_same_v<BetaDataType, half_t> &&
is_same_v<HDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<HLayout, Row>)
{
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<HLayout, Row>)
{
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<HLayout, Row>)
{
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
is_same_v<HLayout, Row>)
{
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
add_instance_library(device_batched_gemm_softmax_gemm_permute_instance
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
MaskingSpecialization MaskingSpec>
using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances =
std::tuple<
// clang-format off
// #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec|
// #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| |
// #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| |
// #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
// Padded fallback kernel
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>
// clang-format on
>;
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<
2,
1,
1,
1,
1,
MaskingSpecialization::MaskOutUpperTriangle>{});
}
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
BF16,
BF16,
BF16,
BF16,
ck::Tuple<BF16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<
2,
1,
1,
1,
1,
MaskingSpecialization::MaskDisabled>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO,
MaskingSpecialization MaskingSpec>
using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
std::tuple<
// clang-format off
// #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec|
// #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| |
// #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| |
// #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
// Padded fallback kernel
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>
// clang-format on
>;
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>&
instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<
2,
1,
1,
1,
1,
MaskingSpecialization::MaskOutUpperTriangle>{});
}
void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
F16,
F16,
F16,
F16,
ck::Tuple<F16>,
ck::Tuple<>,
PassThrough,
PassThrough,
ScaleAdd,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<
2,
1,
1,
1,
1,
MaskingSpecialization::MaskDisabled>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -23,6 +23,11 @@ template <typename XElementwise, typename YElementwise, index_t Rank, index_t Re
using device_elementwise_normalization_f16_instances =
std::tuple <
// XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
DeviceElementwiseNormalizationImpl<ck::Tuple<F16, F16>, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel for large N
DeviceElementwiseNormalizationImpl<ck::Tuple<F16, F16>, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 2, 1, 2, 1, 2, 2>, // fallback kernel for large N
DeviceElementwiseNormalizationImpl<ck::Tuple<F16, F16>, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8>, // fallback kernel for large N
DeviceElementwiseNormalizationImpl<ck::Tuple<F16, F16>, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 2, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8>, // fallback kernel for large N
DeviceElementwiseNormalizationImpl<ck::Tuple<F16, F16>, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 4, 64, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel for large N
DeviceElementwiseNormalizationImpl<ck::Tuple<F16, F16>, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel
DeviceElementwiseNormalizationImpl<ck::Tuple<F16, F16>, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 1, 2, 1, 2, 2>, // fallback kernel
DeviceElementwiseNormalizationImpl<ck::Tuple<F16, F16>, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 1, 4, 1, 4, 4>, // fallback kernel
......
add_instance_library(device_gemm_add_relu_add_layernorm_instance
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment