Commit ddc76a8b authored by Anthony Chang's avatar Anthony Chang
Browse files

AccElemOp for gemm outputs prior to feeding to layernorm

parent 12db5b6d
......@@ -36,20 +36,29 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
struct Relu
{
template<typename OutT, typename InT>
__host__ __device__ void operator()(OutT& y, const InT& x) const
{
y = x > 0 ? x : 0;
}
};
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
// using AccElementOp = ck::tensor_operation::element_wise::PassThrough;
using AccElementOp = Relu;
using CElementOp = Relu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmLayerNorm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| C0Data| GemmAcc| CShuffle| ReduceAcc| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| Type| DataType| DataType| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, ADataType, BDataType, CDataType, C0DataType, AccDataType, AccDataType, AccDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
//######| ALayout| BLayout| CLayout| AData| BData| CData| C0Data| GemmAcc| CShuffle| ReduceAcc| A| B| Acc| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| Type| DataType| DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, ADataType, BDataType, CDataType, C0DataType, AccDataType, AccDataType, AccDataType, AElementOp, BElementOp, AccElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceInstance = ck::tensor_operation::host::ReferenceGemmLayernorm<ADataType,
......@@ -59,6 +68,7 @@ using ReferenceInstance = ck::tensor_operation::host::ReferenceGemmLayernorm<ADa
AccDataType,
AElementOp,
BElementOp,
AccElementOp,
CElementOp>;
int main(int argc, char* argv[])
......@@ -176,6 +186,7 @@ int main(int argc, char* argv[])
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto acc_element_op = AccElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
......@@ -195,6 +206,7 @@ int main(int argc, char* argv[])
StrideC,
a_element_op,
b_element_op,
acc_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
......@@ -235,6 +247,7 @@ int main(int argc, char* argv[])
c_m_n_host_result,
a_element_op,
b_element_op,
acc_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
......
......@@ -33,6 +33,7 @@ template <typename ALayout,
typename ReduceAccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
......@@ -380,6 +381,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
ReduceAccDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
......@@ -439,6 +441,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
......@@ -455,6 +458,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
block_2_ctile_map_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
c_element_op_{c_element_op}
{
if(GridwiseGemm::CheckValidity(
......@@ -489,6 +493,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
AccElementwiseOperation acc_element_op_;
CElementwiseOperation c_element_op_;
};
......@@ -538,6 +543,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
C0DataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
......@@ -560,6 +566,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
arg.p_c0_beta_,
arg.a_element_op_,
arg.b_element_op_,
arg.acc_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
......@@ -576,6 +583,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
C0DataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
......@@ -597,6 +605,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
arg.p_c0_beta_,
arg.a_element_op_,
arg.b_element_op_,
arg.acc_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
......@@ -648,6 +657,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
......@@ -664,6 +674,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
StrideC,
a_element_op,
b_element_op,
acc_element_op,
c_element_op};
}
......@@ -683,6 +694,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
CElementwiseOperation c_element_op,
index_t /* KBatch */ = 1)
{
......@@ -700,6 +712,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
StrideC,
a_element_op,
b_element_op,
acc_element_op,
c_element_op);
}
......
......@@ -21,6 +21,7 @@ template <typename GridwiseGemm,
typename FloatC0,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
......@@ -41,6 +42,7 @@ __global__ void
const FloatC0* __restrict__ p_c0_beta_grid, // 1xN
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
......@@ -62,6 +64,7 @@ __global__ void
p_shared,
a_element_op,
b_element_op,
acc_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
......@@ -79,6 +82,7 @@ __global__ void
ignore = p_c0_beta_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
......@@ -99,6 +103,7 @@ template <typename FloatAB,
typename FloatReduceAcc, // Data type after shuffle
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
......@@ -377,6 +382,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const AccElementwiseOperation& acc_element_op,
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
......@@ -630,7 +636,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
AccElementwiseOperation,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
......@@ -654,7 +660,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
acc_element_op};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
......
......@@ -16,6 +16,7 @@ template <typename ADataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename CElementwiseOperation>
struct ReferenceGemmLayernorm : public device::BaseOperator
{
......@@ -25,7 +26,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
AccDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
AccElementwiseOperation>;
// D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta)
template <typename InDataType, typename OutDataType, typename ComputeDataType>
......@@ -95,6 +96,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
CElementwiseOperation c_element_op,
const CDataType epsilon = 1e-5)
: a_m_k_{a_m_k},
......@@ -105,6 +107,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
c_m_n_{c_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
c_element_op_{c_element_op},
epsilon_{epsilon}
{
......@@ -119,7 +122,9 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
AccElementwiseOperation acc_element_op_;
CElementwiseOperation c_element_op_;
const CDataType epsilon_;
};
......@@ -140,11 +145,16 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
acc_m_n,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
arg.acc_element_op_);
ref_invoker.Run(ref_argument);
RunLayernorm(arg.c_m_n_, acc_m_n, arg.c0_n_bias_, arg.c0_n_gamma_, arg.c0_n_beta_);
arg.c_m_n_.ForEach([&](auto& self, auto idx){
arg.c_element_op_(self(idx[0], idx[1]), self(idx[0], idx[1]));
});
return 0;
}
......@@ -171,6 +181,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
CElementwiseOperation c_element_op,
const CDataType epsilon = 1e-5)
{
......@@ -182,6 +193,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
c_m_n,
a_element_op,
b_element_op,
acc_element_op,
c_element_op,
epsilon};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment