Commit c44818e7 authored by rocking's avatar rocking
Browse files

Rename DxsReduceAccElementwiseOperation to DxsReduceAccElementwiseOperation

parent 46eca0a1
...@@ -22,7 +22,7 @@ template <typename GridwiseGemm, ...@@ -22,7 +22,7 @@ template <typename GridwiseGemm,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename DxsInElementwiseOperation, typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation, typename DxsReduceAccElementwiseOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -44,7 +44,7 @@ __global__ void ...@@ -44,7 +44,7 @@ __global__ void
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const DxsInElementwiseOperation dxs_in_element_op, const DxsInElementwiseOperation dxs_in_element_op,
const DxsAccElementwiseOperation dxs_out_element_op, const DxsReduceAccElementwiseOperation dxs_out_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -126,7 +126,7 @@ template <typename ALayout, ...@@ -126,7 +126,7 @@ template <typename ALayout,
typename CElementwiseOperation, typename CElementwiseOperation,
typename DxsReduceOperation, typename DxsReduceOperation,
typename DxsInElementwiseOperation, typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation, typename DxsReduceAccElementwiseOperation,
typename DGlobalMemoryDataOperation, typename DGlobalMemoryDataOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
...@@ -162,12 +162,13 @@ template <typename ALayout, ...@@ -162,12 +162,13 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, struct DeviceBatchedGemmReduce_Xdl_CShuffle
: public DeviceGemmReduce<DPtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation> DxsReduceAccElementwiseOperation>
{ {
using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle; using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
...@@ -527,7 +528,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -527,7 +528,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
CElementwiseOperation, CElementwiseOperation,
DxsReduceOperation, DxsReduceOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation, DxsReduceAccElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation, DGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
...@@ -587,7 +588,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -587,7 +588,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op, DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op, DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t BatchCount) index_t BatchCount)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
...@@ -645,7 +646,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -645,7 +646,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
DxsInElementwiseOperation dxs_in_element_op_; DxsInElementwiseOperation dxs_in_element_op_;
DxsAccElementwiseOperation dxs_out_element_op_; DxsReduceAccElementwiseOperation dxs_out_element_op_;
}; };
// Invoker // Invoker
...@@ -703,7 +704,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -703,7 +704,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation, DxsReduceAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -746,7 +747,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -746,7 +747,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation, DxsReduceAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -832,7 +833,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -832,7 +833,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op, DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op, DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t BatchCount) index_t BatchCount)
{ {
return Argument{p_a, return Argument{p_a,
...@@ -856,7 +857,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -856,7 +857,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
void* p_c, void* p_c,
DPtrsGlobal p_dxs, DPtrsGlobal p_dxs,
...@@ -870,7 +872,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba ...@@ -870,7 +872,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGloba
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op, DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op, DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t BatchCount) override index_t BatchCount) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
......
...@@ -35,7 +35,7 @@ template <typename ALayout, ...@@ -35,7 +35,7 @@ template <typename ALayout,
typename C1ElementwiseOperation, typename C1ElementwiseOperation,
typename DxsReduceOperation, typename DxsReduceOperation,
typename DxsInElementwiseOperation, typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation, typename DxsReduceAccElementwiseOperation,
typename DGlobalMemoryDataOperation, typename DGlobalMemoryDataOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
...@@ -78,7 +78,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -78,7 +78,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
CElementwiseOperation, CElementwiseOperation,
C1ElementwiseOperation, C1ElementwiseOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation> DxsReduceAccElementwiseOperation>
{ {
using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle; using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle;
...@@ -399,7 +399,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -399,7 +399,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
C1ElementwiseOperation, C1ElementwiseOperation,
DxsReduceOperation, DxsReduceOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation, DxsReduceAccElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation, DGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
...@@ -465,7 +465,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -465,7 +465,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op, C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op, DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op) DxsReduceAccElementwiseOperation dxs_out_element_op)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
...@@ -538,7 +538,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -538,7 +538,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
C1ElementwiseOperation c1_element_op_; C1ElementwiseOperation c1_element_op_;
DxsInElementwiseOperation dxs_in_element_op_; DxsInElementwiseOperation dxs_in_element_op_;
DxsAccElementwiseOperation dxs_out_element_op_; DxsReduceAccElementwiseOperation dxs_out_element_op_;
}; };
// Invoker // Invoker
...@@ -577,7 +577,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -577,7 +577,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
CElementwiseOperation, CElementwiseOperation,
C1ElementwiseOperation, C1ElementwiseOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation, DxsReduceAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -627,7 +627,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -627,7 +627,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
CElementwiseOperation, CElementwiseOperation,
C1ElementwiseOperation, C1ElementwiseOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation, DxsReduceAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -713,7 +713,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -713,7 +713,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op, C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op, DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op) DxsReduceAccElementwiseOperation dxs_out_element_op)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -739,7 +739,8 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -739,7 +739,8 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
void* p_c, void* p_c,
const void* p_c0, const void* p_c0,
...@@ -757,7 +758,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -757,7 +758,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op, C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op, DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op, DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t /* KBatch */ = 1) override index_t /* KBatch */ = 1) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
......
...@@ -11,7 +11,7 @@ template <typename DPtrsGlobal, ...@@ -11,7 +11,7 @@ template <typename DPtrsGlobal,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename DxsInElementwiseOperation, typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation> typename DxsReduceAccElementwiseOperation>
struct DeviceGemmReduce : public BaseOperator struct DeviceGemmReduce : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
...@@ -29,7 +29,7 @@ struct DeviceGemmReduce : public BaseOperator ...@@ -29,7 +29,7 @@ struct DeviceGemmReduce : public BaseOperator
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op, DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op, DxsReduceAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0; ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
...@@ -40,13 +40,13 @@ template <typename DPtrsGlobal, ...@@ -40,13 +40,13 @@ template <typename DPtrsGlobal,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename DxsInElementwiseOperation, typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation> typename DxsReduceAccElementwiseOperation>
using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal, using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation>>; DxsReduceAccElementwiseOperation>>;
template <typename DPtrsGlobal, template <typename DPtrsGlobal,
typename AElementwiseOperation, typename AElementwiseOperation,
...@@ -54,7 +54,7 @@ template <typename DPtrsGlobal, ...@@ -54,7 +54,7 @@ template <typename DPtrsGlobal,
typename CElementwiseOperation, typename CElementwiseOperation,
typename C1ElementwiseOperation, typename C1ElementwiseOperation,
typename DxsInElementwiseOperation, typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation> typename DxsReduceAccElementwiseOperation>
struct DeviceGemmBiasAddReduce : public BaseOperator struct DeviceGemmBiasAddReduce : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
...@@ -76,7 +76,7 @@ struct DeviceGemmBiasAddReduce : public BaseOperator ...@@ -76,7 +76,7 @@ struct DeviceGemmBiasAddReduce : public BaseOperator
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op, C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op, DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op, DxsReduceAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0; ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
...@@ -88,7 +88,7 @@ template <typename DPtrsGlobal, ...@@ -88,7 +88,7 @@ template <typename DPtrsGlobal,
typename CElementwiseOperation, typename CElementwiseOperation,
typename C1ElementwiseOperation, typename C1ElementwiseOperation,
typename DxsInElementwiseOperation, typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation> typename DxsReduceAccElementwiseOperation>
using DeviceGemmBiasAddReducePtr = using DeviceGemmBiasAddReducePtr =
std::unique_ptr<DeviceGemmBiasAddReduce<DPtrsGlobal, std::unique_ptr<DeviceGemmBiasAddReduce<DPtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
...@@ -96,7 +96,7 @@ using DeviceGemmBiasAddReducePtr = ...@@ -96,7 +96,7 @@ using DeviceGemmBiasAddReducePtr =
CElementwiseOperation, CElementwiseOperation,
C1ElementwiseOperation, C1ElementwiseOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation>>; DxsReduceAccElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -32,7 +32,7 @@ template <typename ALayout, ...@@ -32,7 +32,7 @@ template <typename ALayout,
typename CElementwiseOperation, typename CElementwiseOperation,
typename DxsReduceOperation, typename DxsReduceOperation,
typename DxsInElementwiseOperation, typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation, typename DxsReduceAccElementwiseOperation,
typename DGlobalMemoryDataOperation, typename DGlobalMemoryDataOperation,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
...@@ -73,7 +73,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -73,7 +73,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation> DxsReduceAccElementwiseOperation>
{ {
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle; using DeviceOp = DeviceGemmReduce_Xdl_CShuffle;
...@@ -389,7 +389,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -389,7 +389,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
CElementwiseOperation, CElementwiseOperation,
DxsReduceOperation, DxsReduceOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation, DxsReduceAccElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
DGlobalMemoryDataOperation, DGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
...@@ -449,7 +449,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -449,7 +449,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op, DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op) DxsReduceAccElementwiseOperation dxs_out_element_op)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
...@@ -498,7 +498,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -498,7 +498,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
DxsInElementwiseOperation dxs_in_element_op_; DxsInElementwiseOperation dxs_in_element_op_;
DxsAccElementwiseOperation dxs_out_element_op_; DxsReduceAccElementwiseOperation dxs_out_element_op_;
}; };
// Invoker // Invoker
...@@ -554,7 +554,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -554,7 +554,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation, DxsReduceAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -594,7 +594,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -594,7 +594,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation, DxsReduceAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -669,7 +669,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -669,7 +669,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op, DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op) DxsReduceAccElementwiseOperation dxs_out_element_op)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -691,7 +691,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -691,7 +691,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
void* p_c, void* p_c,
DPtrsGlobal p_dxs, DPtrsGlobal p_dxs,
...@@ -705,7 +706,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -705,7 +706,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op, DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op, DxsReduceAccElementwiseOperation dxs_out_element_op,
index_t /* KBatch */ = 1) override index_t /* KBatch */ = 1) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
......
...@@ -24,7 +24,7 @@ template <typename GridwiseGemm, ...@@ -24,7 +24,7 @@ template <typename GridwiseGemm,
typename CElementwiseOperation, typename CElementwiseOperation,
typename C1ElementwiseOperation, typename C1ElementwiseOperation,
typename DxsInElementwiseOperation, typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation, typename DxsReduceAccElementwiseOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -49,7 +49,7 @@ __global__ void ...@@ -49,7 +49,7 @@ __global__ void
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const C1ElementwiseOperation c1_element_op, const C1ElementwiseOperation c1_element_op,
const DxsInElementwiseOperation dxs_in_element_op, const DxsInElementwiseOperation dxs_in_element_op,
const DxsAccElementwiseOperation dxs_out_element_op, const DxsReduceAccElementwiseOperation dxs_out_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -121,7 +121,7 @@ template <typename FloatAB, ...@@ -121,7 +121,7 @@ template <typename FloatAB,
typename C1ElementwiseOperation, typename C1ElementwiseOperation,
typename DxsReduceOperation, typename DxsReduceOperation,
typename DxsInElementwiseOperation, typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation, typename DxsReduceAccElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename DGlobalMemoryDataOperation, typename DGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
...@@ -366,7 +366,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -366,7 +366,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const C1ElementwiseOperation& c1_element_op, const C1ElementwiseOperation& c1_element_op,
const DxsInElementwiseOperation& dxs_in_element_op, const DxsInElementwiseOperation& dxs_in_element_op,
const DxsAccElementwiseOperation& dxs_out_element_op, const DxsReduceAccElementwiseOperation& dxs_out_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
......
...@@ -21,7 +21,7 @@ template <typename GridwiseGemm, ...@@ -21,7 +21,7 @@ template <typename GridwiseGemm,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename DxsInElementwiseOperation, typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation, typename DxsReduceAccElementwiseOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
...@@ -41,7 +41,7 @@ __global__ void ...@@ -41,7 +41,7 @@ __global__ void
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const DxsInElementwiseOperation dxs_in_element_op, const DxsInElementwiseOperation dxs_in_element_op,
const DxsAccElementwiseOperation dxs_out_element_op, const DxsReduceAccElementwiseOperation dxs_out_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -96,7 +96,7 @@ template <typename FloatAB, ...@@ -96,7 +96,7 @@ template <typename FloatAB,
typename CElementwiseOperation, typename CElementwiseOperation,
typename DxsReduceOperation, typename DxsReduceOperation,
typename DxsInElementwiseOperation, typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation, typename DxsReduceAccElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename DGlobalMemoryDataOperation, typename DGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
...@@ -329,7 +329,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -329,7 +329,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const DxsInElementwiseOperation& dxs_in_element_op, const DxsInElementwiseOperation& dxs_in_element_op,
const DxsAccElementwiseOperation& dxs_out_element_op, const DxsReduceAccElementwiseOperation& dxs_out_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment