Commit 300ac4e4 authored by rocking's avatar rocking
Browse files

Implement gemm bias add reduction

parent 09a2b547
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp)
add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp) add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp)
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_reduce_xdl_cshuffle_v1.hpp" #include "gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp"
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
namespace ck { namespace ck {
...@@ -23,6 +23,8 @@ template <typename ALayout, ...@@ -23,6 +23,8 @@ template <typename ALayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename CDataType,
typename C0DataType,
typename C1DataType,
typename GemmAccDataType, typename GemmAccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename ReduceAccDataType, typename ReduceAccDataType,
...@@ -68,14 +70,15 @@ template <typename ALayout, ...@@ -68,14 +70,15 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, struct DeviceGemmBiasAddReduce_Xdl_CShuffle
AElementwiseOperation, : public DeviceGemmBiasAddReduce<DPtrsGlobal,
BElementwiseOperation, AElementwiseOperation,
CElementwiseOperation, BElementwiseOperation,
DxsInElementwiseOperation, CElementwiseOperation,
DxsAccElementwiseOperation> DxsInElementwiseOperation,
DxsAccElementwiseOperation>
{ {
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle; using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -374,14 +377,18 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -374,14 +377,18 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 0));
using C1GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1)); using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< using GridwiseGemm = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
CDataType, CDataType,
C0DataType,
C1DataType,
ReduceAccDataType, ReduceAccDataType,
DPtrsGlobal, DPtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
...@@ -395,6 +402,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -395,6 +402,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
CGridDesc_M_N, CGridDesc_M_N,
C0GridDesc_M_N,
C1GridDesc_M_N,
DGridDesc_M, DGridDesc_M,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
...@@ -438,6 +447,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -438,6 +447,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
Argument(const ADataType* p_a_grid, Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid, const BDataType* p_b_grid,
CDataType* p_c_grid, CDataType* p_c_grid,
const C0DataType* p_c0_grid,
const C1DataType* p_c1_grid,
DPtrsGlobal p_ds_grid, DPtrsGlobal p_ds_grid,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
...@@ -445,6 +456,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -445,6 +456,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t StrideC1,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
...@@ -453,12 +465,18 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -453,12 +465,18 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
p_c0_grid_{p_c0_grid},
p_c1_grid_{p_c1_grid},
p_ds_grid_{p_ds_grid}, p_ds_grid_{p_ds_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
c0_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, 0)},
c1_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC1)},
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)}, d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{},
c0_grid_desc_mblock_mperblock_nblock_nperblock_{},
c1_grid_desc_mblock_mperblock_nblock_nperblock_{},
d_grid_desc_mblock_mperblock_{}, d_grid_desc_mblock_mperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
...@@ -476,6 +494,14 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -476,6 +494,14 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_); c_grid_desc_m_n_);
c0_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c0_grid_desc_m_n_);
c1_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c1_grid_desc_m_n_);
d_grid_desc_mblock_mperblock_ = d_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_); GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
} }
...@@ -485,13 +511,21 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -485,13 +511,21 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
const C0DataType* p_c0_grid_;
const C1DataType* p_c1_grid_;
DPtrsGlobal p_ds_grid_; DPtrsGlobal p_ds_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
C0GridDesc_M_N c0_grid_desc_m_n_;
C1GridDesc_M_N c1_grid_desc_m_n_;
DGridDesc_M d_grid_desc_m_; DGridDesc_M d_grid_desc_m_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c0_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_; typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -508,26 +542,6 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -508,26 +542,6 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}"
<< std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
...@@ -545,10 +559,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -545,10 +559,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
float elapsed_time = 0.0f; float elapsed_time = 0.0f;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1< const auto kernel = kernel_gemm_bias_add_reduce_xdl_cshuffle_v1<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
C0DataType,
C1DataType,
DPtrsGlobal, DPtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
...@@ -558,6 +574,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -558,6 +574,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
true>; true>;
...@@ -571,6 +589,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -571,6 +589,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.p_ds_grid_, arg.p_ds_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
...@@ -580,15 +600,19 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -580,15 +600,19 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_, arg.d_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
else else
{ {
const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1< const auto kernel = kernel_gemm_bias_add_reduce_xdl_cshuffle_v1<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
C0DataType,
C1DataType,
DPtrsGlobal, DPtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
...@@ -598,6 +622,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -598,6 +622,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
false>; false>;
...@@ -611,6 +637,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -611,6 +637,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.p_ds_grid_, arg.p_ds_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
...@@ -620,6 +648,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -620,6 +648,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_, arg.d_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
...@@ -658,6 +688,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -658,6 +688,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
static auto MakeArgument(const ADataType* p_a, static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b, const BDataType* p_b,
CDataType* p_c, CDataType* p_c,
const C0DataType* p_c0,
const C1DataType* p_c1,
DPtrsGlobal p_dxs, DPtrsGlobal p_dxs,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
...@@ -665,6 +697,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -665,6 +697,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t StrideC1,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
...@@ -674,6 +707,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -674,6 +707,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
return Argument{p_a, return Argument{p_a,
p_b, p_b,
p_c, p_c,
p_c0,
p_c1,
p_dxs, p_dxs,
MRaw, MRaw,
NRaw, NRaw,
...@@ -681,6 +716,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -681,6 +716,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
StrideC1,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
...@@ -694,6 +730,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -694,6 +730,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
void* p_c, void* p_c,
const void* p_c0,
const void* p_c1,
DPtrsGlobal p_dxs, DPtrsGlobal p_dxs,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
...@@ -701,6 +739,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -701,6 +739,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t StrideC1,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
...@@ -711,6 +750,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -711,6 +750,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
static_cast<const C0DataType*>(p_c0),
static_cast<const C1DataType*>(p_c1),
p_dxs, p_dxs,
MRaw, MRaw,
NRaw, NRaw,
...@@ -718,6 +759,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -718,6 +759,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
StrideC1,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
......
...@@ -48,6 +48,52 @@ using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal, ...@@ -48,6 +48,52 @@ using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation>>; DxsAccElementwiseOperation>>;
template <typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation>
struct DeviceGemmBiasAddReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0,
const void* p_c1,
DPtrsGlobal p_dxs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation>
using DeviceGemmBiasAddReducePtr =
std::unique_ptr<DeviceGemmBiasAddReduce<DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -143,6 +143,21 @@ struct AddHardswishAdd ...@@ -143,6 +143,21 @@ struct AddHardswishAdd
} }
}; };
struct Relu
{
__host__ __device__ void operator()(float& y, const float& x) const { y = x > 0 ? x : 0; }
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x > 0 ? x : 0; }
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { y = x > 0 ? x : 0; }
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x > 0 ? x : 0; }
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x > 0 ? x : 0; }
__host__ __device__ void operator()(double& y, const double& x) const { y = x > 0 ? x : 0; }
};
struct Normalize struct Normalize
{ {
Normalize(float epsilon = 1e-4) : epsilon_(epsilon) {} Normalize(float epsilon = 1e-4) : epsilon_(epsilon) {}
......
...@@ -16,6 +16,8 @@ namespace ck { ...@@ -16,6 +16,8 @@ namespace ck {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
typename FloatC0,
typename FloatC1,
typename DPtrsGlobal, typename DPtrsGlobal,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
...@@ -25,6 +27,8 @@ template <typename GridwiseGemm, ...@@ -25,6 +27,8 @@ template <typename GridwiseGemm,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename DGridDescriptor_MBlock_MPerBlock, typename DGridDescriptor_MBlock_MPerBlock,
typename Block2CTileMap, typename Block2CTileMap,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
...@@ -32,10 +36,12 @@ __global__ void ...@@ -32,10 +36,12 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_reduce_xdl_cshuffle_v1( kernel_gemm_bias_add_reduce_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid,
const FloatC1* __restrict__ p_c1_grid,
DPtrsGlobal p_ds_grid, DPtrsGlobal p_ds_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
...@@ -46,6 +52,10 @@ __global__ void ...@@ -46,6 +52,10 @@ __global__ void
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c0_grid_desc_mblock_mperblock_nblock_nperblock,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock, const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
...@@ -55,6 +65,8 @@ __global__ void ...@@ -55,6 +65,8 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_c0_grid,
p_c1_grid,
p_ds_grid, p_ds_grid,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -65,12 +77,16 @@ __global__ void ...@@ -65,12 +77,16 @@ __global__ void
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c1_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_mblock_mperblock, d_grid_desc_mblock_mperblock,
block_2_ctile_map); block_2_ctile_map);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_c_grid; ignore = p_c_grid;
ignore = p_c0_grid;
ignore = p_c1_grid;
ignore = p_ds_grid; ignore = p_ds_grid;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
...@@ -80,6 +96,8 @@ __global__ void ...@@ -80,6 +96,8 @@ __global__ void
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = c0_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = d_grid_desc_mblock_mperblock; ignore = d_grid_desc_mblock_mperblock;
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
...@@ -89,6 +107,8 @@ template <typename FloatAB, ...@@ -89,6 +107,8 @@ template <typename FloatAB,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
typename FloatC, typename FloatC,
typename FloatC0,
typename FloatC1,
typename FloatReduceAcc, typename FloatReduceAcc,
typename DPtrsGlobal, typename DPtrsGlobal,
typename AElementwiseOperation, typename AElementwiseOperation,
...@@ -102,6 +122,8 @@ template <typename FloatAB, ...@@ -102,6 +122,8 @@ template <typename FloatAB,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
typename C0GridDesc_M_N,
typename C1GridDesc_M_N,
typename DGridDesc_M, typename DGridDesc_M,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
...@@ -138,7 +160,7 @@ template <typename FloatAB, ...@@ -138,7 +160,7 @@ template <typename FloatAB,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched> LoopScheduler LoopSched>
struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -268,8 +290,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -268,8 +290,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
} }
template <typename CGridDesc_M_N_>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N_& c_grid_desc_m_n)
{ {
const auto M = c_grid_desc_m_n.GetLength(I0); const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1); const auto N = c_grid_desc_m_n.GetLength(I1);
...@@ -313,6 +336,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -313,6 +336,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>; MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C0GridDesc_M_N{}))>;
using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))>;
using DGridDescriptor_MBlock_MPerBlock = using DGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeDGridDescriptor_MBlock_MPerBlock(DGridDesc_M{}))>; remove_cvref_t<decltype(MakeDGridDescriptor_MBlock_MPerBlock(DGridDesc_M{}))>;
...@@ -323,6 +352,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -323,6 +352,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid,
const FloatC1* __restrict__ p_c1_grid,
DPtrsGlobal p_ds_grid, DPtrsGlobal p_ds_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
...@@ -334,6 +365,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -334,6 +365,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c0_grid_desc_mblock_mperblock_nblock_nperblock,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c1_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock, const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map)
{ {
...@@ -343,6 +378,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -343,6 +378,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c0_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c1_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N] // divide block work by [M, N]
const auto block_work_idx = const auto block_work_idx =
...@@ -610,32 +649,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -610,32 +649,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
n_thread_data_on_block_idx[I2]), n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}}; ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR // space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr = constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>, SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
...@@ -759,14 +772,80 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -759,14 +772,80 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}, },
Number<p_ds_grid.Size()>{}); Number<p_ds_grid.Size()>{});
// c0 and c1
constexpr auto c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));
constexpr auto c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock;
auto c01_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatC0,
FloatReduceAcc,
decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
Sequence<0, 1, 2, 3>,
3,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
1,
true>(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
auto c1_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatC1,
FloatReduceAcc,
decltype(c1_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
Sequence<0, 1, 2, 3>,
3,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
1,
true>(
c1_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));
auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
FloatReduceAcc,
FloatC,
decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
tensor_operation::element_wise::PassThrough,
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
3, // DstVectorDim
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
InMemoryDataOperationEnum::Set,
1,
true>{
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]),
tensor_operation::element_wise::PassThrough{}};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) { static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS // each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
...@@ -774,17 +853,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -774,17 +853,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf); c_shuffle_block_buf);
// make sure it's safe to read from LDS // make sure it's safe to write to LDS
block_sync_lds(); block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
// TODO - extract following into reduction_blockwise
{ {
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf, c_shuffle_block_buf,
...@@ -792,6 +862,37 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -792,6 +862,37 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_tuple(I0, I0), make_tuple(I0, I0),
c_reduce_thread_buf); c_reduce_thread_buf);
c0_thread_copy_global_to_vgpr.Run(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c0_grid_buf,
c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c01_thread_buf);
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) {
FloatReduceAcc out;
c_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i));
c_reduce_thread_buf(i) = out;
});
c1_thread_copy_global_to_vgpr.Run(
c1_grid_desc_mblock_mperblock_nblock_nperblock,
c1_grid_buf,
c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c01_thread_buf);
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) { c_reduce_thread_buf(i) += c01_thread_buf(i); });
c_reduce_thread_copy_vgpr_to_global.Run(
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c_reduce_thread_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) { static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
auto& p_d_grid = p_ds_grid[In]; auto& p_d_grid = p_ds_grid[In];
...@@ -858,13 +959,19 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -858,13 +959,19 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C // move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
// move on C0
c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
// move on C1
c1_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
c1_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
} }
}); });
} // Reduction
// Reduction
}
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment