Unverified Commit 6eb55499 authored by rocking5566's avatar rocking5566 Committed by GitHub
Browse files

Gemm + bias + relu + add + layernorm (#272)

* Copy "gemm reduce" to "gemm bias add reduce"

* Implement gemm bias add reduction

* Fix compiler error due to merge from develop

* Add tensor operation for gemm + bias + add + reduce

* Add gemm_bais_add_reduce to ckProfiler

* Add c1 functor

* Refine type

* Use reduceAccDataType instead of explicitly float

* Change to use check_err()

* Do relu in float32 instead of bhalf_t. Because bhalf_t is unsigned

* Refactor relu. using type_trait instead of overloading

* Rename DxsReduceAccElementwiseOperation to DxsReduceAccElementwiseOperation

* Fix denominator

* Refine nameing

* Fix denominator  in host

* Remove useless include header

* Use AccDataType

* Fix static_cast order

* Refine type

* [What] Remove tuple type in the base class
[Why] External api depend on base class. if base class has relationship with type, we will need many class for different type
parent 561ec12f
......@@ -51,8 +51,8 @@ using UnaryDivElementOp =
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, true>;
using UnarySquareElementOp =
ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOp = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using DGlobalMemOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
......@@ -67,7 +67,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
......@@ -204,8 +204,8 @@ int main(int argc, char* argv[])
auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
auto dxs_in_element_op = DxsInElementOp{};
auto dxs_out_element_op = DxsOutElementOp{M, M};
auto dxs_in_element_op = DxsInElementOps{};
auto dxs_out_element_op = DxsOutElementOps{N, N};
// do GEMM
auto gemm = DeviceGemmReduceInstance{};
......@@ -261,14 +261,15 @@ int main(int argc, char* argv[])
for(int m = 0; m < M; ++m)
{
float d0_acc = d0_reduce_op.GetIdentityValue();
float d1_acc = d1_reduce_op.GetIdentityValue();
ReduceAccDataType d0_acc = d0_reduce_op.GetIdentityValue();
ReduceAccDataType d1_acc = d1_reduce_op.GetIdentityValue();
for(int n = 0; n < N; ++n)
{
float c_val = ck::type_convert<float>(c_m_n_host_result(m, n));
float d0_val = 0;
float d1_val = 0;
ReduceAccDataType c_val =
ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
ReduceAccDataType d0_val = 0;
ReduceAccDataType d1_val = 0;
dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
......
......@@ -47,8 +47,8 @@ using UnaryIdenticElementOp =
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>;
using UnarySquareElementOp =
ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOp = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
using DGlobalMemOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
......@@ -63,7 +63,7 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
......@@ -206,8 +206,8 @@ int main(int argc, char* argv[])
a_element_op,
b_element_op,
c_element_op,
DxsInElementOp{},
DxsOutElementOp{},
DxsInElementOps{},
DxsOutElementOps{},
BatchCount);
if(!batched_gemm.IsSupportedArgument(argument))
......
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp)
add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp)
......@@ -2,7 +2,6 @@
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include "check_err.hpp"
#include "config.hpp"
......@@ -54,8 +53,8 @@ using UnaryDivElementOp =
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, true>;
using UnarySquareElementOp =
ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOp = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using DxsGlobalMemOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
......@@ -70,7 +69,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
......@@ -143,7 +142,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<DDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<DDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
auto averageOpInst = UnaryDivElementOp{M};
auto averageOpInst = UnaryDivElementOp{N};
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
......@@ -162,7 +161,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
for(int n = 0; n < N; ++n)
{
ReduceAccDataType c_val = ck::type_convert<float>(c_m_n(m, n));
ReduceAccDataType c_val = ck::type_convert<ReduceAccDataType>(c_m_n(m, n));
ReduceAccDataType square_c_val = 0;
UnarySquareElementOp{}(square_c_val, c_val);
......@@ -267,8 +266,8 @@ int main()
ck::make_tuple(static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()));
auto dxs_in_element_op = DxsInElementOp{};
auto dxs_out_element_op = DxsOutElementOp{M, M};
auto dxs_in_element_op = DxsInElementOps{};
auto dxs_out_element_op = DxsOutElementOps{N, N};
// Prepare GEMM, reduce_mean, reduce_mean_square
auto gemmReduce = DeviceGemmReduceInstance{};
......
......@@ -3,7 +3,6 @@
#include <sstream>
#include "device.hpp"
#include "device_base.hpp"
#include "common_header.hpp"
#include "gridwise_5ary_Elementwise_1d.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
......
......@@ -22,7 +22,7 @@ template <typename GridwiseGemm,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -44,7 +44,7 @@ __global__ void
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const DxsInElementwiseOperation dxs_in_element_op,
const DxsAccElementwiseOperation dxs_out_element_op,
const DxsReduceAccElementwiseOperation dxs_out_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -126,7 +126,7 @@ template <typename ALayout,
typename CElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename DGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
......@@ -162,12 +162,12 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
AElementwiseOperation,
struct DeviceBatchedGemmReduce_Xdl_CShuffle
: public DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation>
DxsReduceAccElementwiseOperation>
{
using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
......@@ -527,7 +527,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
CElementwiseOperation,
DxsReduceOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation,
DxsReduceAccElementwiseOperation,
InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1,
......@@ -587,7 +587,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t BatchCount)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
......@@ -645,7 +645,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
DxsInElementwiseOperation dxs_in_element_op_;
DxsAccElementwiseOperation dxs_out_element_op_;
DxsReduceAccElementwiseOperation dxs_out_element_op_;
};
// Invoker
......@@ -703,7 +703,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation,
DxsReduceAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -746,7 +746,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation,
DxsReduceAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -832,7 +832,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t BatchCount)
{
return Argument{p_a,
......@@ -856,10 +856,11 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
DPtrsGlobal p_dxs,
void* p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
......@@ -870,13 +871,14 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t BatchCount) override
{
DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
p_dxs,
dxs_tuple,
MRaw,
NRaw,
KRaw,
......
......@@ -6,19 +6,18 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <typename DPtrsGlobal,
typename AElementwiseOperation,
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation>
typename DxsReduceAccElementwiseOperation>
struct DeviceGemmReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
DPtrsGlobal p_dxs,
void* p_dxs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
......@@ -29,24 +28,69 @@ struct DeviceGemmReduce : public BaseOperator
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename DPtrsGlobal,
typename AElementwiseOperation,
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal,
AElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation>>;
DxsReduceAccElementwiseOperation>>;
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
struct DeviceGemmBiasAddReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0,
const void* p_c1,
void* p_dxs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsReduceAccElementwiseOperation>
using DeviceGemmBiasAddReducePtr =
std::unique_ptr<DeviceGemmBiasAddReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation,
DxsReduceAccElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
......
......@@ -32,7 +32,7 @@ template <typename ALayout,
typename CElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename DGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
......@@ -68,12 +68,11 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
AElementwiseOperation,
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation>
DxsReduceAccElementwiseOperation>
{
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
......@@ -389,7 +388,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
CElementwiseOperation,
DxsReduceOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation,
DxsReduceAccElementwiseOperation,
InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1,
......@@ -449,7 +448,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op)
DxsReduceAccElementwiseOperation dxs_out_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
......@@ -498,7 +497,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
DxsInElementwiseOperation dxs_in_element_op_;
DxsAccElementwiseOperation dxs_out_element_op_;
DxsReduceAccElementwiseOperation dxs_out_element_op_;
};
// Invoker
......@@ -554,7 +553,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation,
DxsReduceAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -594,7 +593,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation,
DxsReduceAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -669,7 +668,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op)
DxsReduceAccElementwiseOperation dxs_out_element_op)
{
return Argument{p_a,
p_b,
......@@ -691,10 +690,11 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
DPtrsGlobal p_dxs,
void* p_dxs,
index_t MRaw,
index_t NRaw,
index_t KRaw,
......@@ -705,13 +705,14 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op,
DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t /* KBatch */ = 1) override
{
DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
p_dxs,
dxs_tuple,
MRaw,
NRaw,
KRaw,
......
......@@ -144,6 +144,27 @@ struct AddHardswishAdd
}
};
struct Relu
{
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
y = x > 0 ? x : 0;
}
template <>
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
{
float x_f32 = ck::type_convert<float>(x);
float y_f32 = x_f32 > 0 ? x_f32 : 0;
y = ck::type_convert<bhalf_t>(y_f32);
}
};
struct Normalize
{
Normalize(float epsilon = 1e-4) : epsilon_(epsilon) {}
......
......@@ -21,7 +21,7 @@ template <typename GridwiseGemm,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
......@@ -41,7 +41,7 @@ __global__ void
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const DxsInElementwiseOperation dxs_in_element_op,
const DxsAccElementwiseOperation dxs_out_element_op,
const DxsReduceAccElementwiseOperation dxs_out_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -96,7 +96,7 @@ template <typename FloatAB,
typename CElementwiseOperation,
typename DxsReduceOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation,
typename DxsReduceAccElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename DGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
......@@ -329,7 +329,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const DxsInElementwiseOperation& dxs_in_element_op,
const DxsAccElementwiseOperation& dxs_out_element_op,
const DxsReduceAccElementwiseOperation& dxs_out_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
......
......@@ -30,6 +30,7 @@ add_subdirectory(gemm_bias2d)
add_subdirectory(gemm_bias_relu)
add_subdirectory(gemm_bias_relu_add)
add_subdirectory(gemm_reduce)
add_subdirectory(gemm_bias_add_reduce)
add_subdirectory(batched_gemm)
add_subdirectory(conv1d_fwd)
add_subdirectory(conv2d_fwd)
......
......@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_in
>;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
std::vector<DeviceGemmReducePtr<DPtrsGlobal,
PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
std::vector<
DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
instances)
{
add_device_operation_instances(
instances,
......
......@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_in
>;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
std::vector<DeviceGemmReducePtr<DPtrsGlobal,
PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
std::vector<
DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
instances)
{
add_device_operation_instances(
instances,
......
......@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_in
>;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
std::vector<DeviceGemmReducePtr<DPtrsGlobal,
PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
std::vector<
DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
instances)
{
add_device_operation_instances(
instances,
......
......@@ -59,12 +59,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_in
>;
void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
std::vector<DeviceGemmReducePtr<DPtrsGlobal,
PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
std::vector<
DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
instances)
{
add_device_operation_instances(
instances,
......
set(DEVICE_GEMM_REDUCE_INSTANCE_SOURCE
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
)
add_instance_library(device_gemm_bias_add_reduce_instance ${DEVICE_GEMM_REDUCE_INSTANCE_SOURCE})
install(TARGETS device_gemm_bias_add_reduce_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_gemm_bias_add_reduce_instance)
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReduceSum = ck::reduce::Add<F32>;
using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Div, Div>;
using ReduceMemOp = ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// c[m, n] = a[k, m] * b[k, n]
using device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>,
DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>
// clang-format on
>;
void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmBiasAddReducePtr<PassThrough,
PassThrough,
PassThrough,
PassThrough,
DInElementOps,
DOutElementOps>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment