Commit 78f28a13 authored by rocking's avatar rocking
Browse files

Add c1 functor

parent a5842a7f
...@@ -48,6 +48,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; ...@@ -48,6 +48,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough; using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CElementOp = ck::tensor_operation::element_wise::Relu; using CElementOp = ck::tensor_operation::element_wise::Relu;
using C1ElementOp = PassThrough;
using ReduceSumOp = ck::reduce::Add<ReduceAccDataType>; using ReduceSumOp = ck::reduce::Add<ReduceAccDataType>;
using DxsReduceOp = ck::Tuple<ReduceSumOp, ReduceSumOp>; using DxsReduceOp = ck::Tuple<ReduceSumOp, ReduceSumOp>;
...@@ -69,11 +70,11 @@ static constexpr auto GemmSpecialization = ...@@ -69,11 +70,11 @@ static constexpr auto GemmSpecialization =
// clang-format off // clang-format off
using DeviceGemmBiasAddReduceInstance = ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle using DeviceGemmBiasAddReduceInstance = ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, C1ElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
...@@ -131,7 +132,8 @@ template <typename CDataType, ...@@ -131,7 +132,8 @@ template <typename CDataType,
typename C1DataType, typename C1DataType,
typename A_functor, typename A_functor,
typename B_functor, typename B_functor,
typename C_functor> typename C_functor,
typename C1_functor>
void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
const Tensor<ADataType>& a_m_k, const Tensor<ADataType>& a_m_k,
const Tensor<ADataType>& b_k_n, const Tensor<ADataType>& b_k_n,
...@@ -142,6 +144,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, ...@@ -142,6 +144,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
A_functor a_element_op, A_functor a_element_op,
B_functor b_element_op, B_functor b_element_op,
C_functor c_element_op, C_functor c_element_op,
C1_functor c1_element_op,
int M, int M,
int N) int N)
{ {
...@@ -160,14 +163,18 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, ...@@ -160,14 +163,18 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
// c = activation(c + bias) + c1_functor(c1)
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
AccDataType acc = AccDataType acc =
static_cast<AccDataType>(c_m_n(m, n)) + static_cast<AccDataType>(bias_n(n)); static_cast<AccDataType>(c_m_n(m, n)) + static_cast<AccDataType>(bias_n(n));
AccDataType c1 = c1_m_n(m, n);
c_element_op(acc, acc); c_element_op(acc, acc);
acc += static_cast<AccDataType>(c1_m_n(m, n)); c1_element_op(c1, c1);
acc += static_cast<AccDataType>(c1);
c_m_n(m, n) = static_cast<CDataType>(acc); c_m_n(m, n) = static_cast<CDataType>(acc);
} }
...@@ -293,6 +300,7 @@ int main() ...@@ -293,6 +300,7 @@ int main()
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
auto c1_element_op = C1ElementOp{};
auto dxs_global = auto dxs_global =
ck::make_tuple(static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()), ck::make_tuple(static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer())); static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()));
...@@ -320,6 +328,7 @@ int main() ...@@ -320,6 +328,7 @@ int main()
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
c1_element_op,
dxs_in_element_op, dxs_in_element_op,
dxs_out_element_op); dxs_out_element_op);
...@@ -378,6 +387,7 @@ int main() ...@@ -378,6 +387,7 @@ int main()
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
c1_element_op,
M, M,
N); N);
......
...@@ -32,6 +32,7 @@ template <typename ALayout, ...@@ -32,6 +32,7 @@ template <typename ALayout,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsReduceOperation, typename DxsReduceOperation,
typename DxsInElementwiseOperation, typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation, typename DxsAccElementwiseOperation,
...@@ -75,6 +76,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -75,6 +76,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation> DxsAccElementwiseOperation>
{ {
...@@ -394,6 +396,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -394,6 +396,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
C1ElementwiseOperation,
DxsReduceOperation, DxsReduceOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation, DxsAccElementwiseOperation,
...@@ -460,6 +463,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -460,6 +463,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op, DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op) DxsAccElementwiseOperation dxs_out_element_op)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
...@@ -482,6 +486,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -482,6 +486,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op}, c_element_op_{c_element_op},
c1_element_op_{c1_element_op},
dxs_in_element_op_{dxs_in_element_op}, dxs_in_element_op_{dxs_in_element_op},
dxs_out_element_op_{dxs_out_element_op} dxs_out_element_op_{dxs_out_element_op}
{ {
...@@ -531,6 +536,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -531,6 +536,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_; CElementwiseOperation c_element_op_;
C1ElementwiseOperation c1_element_op_;
DxsInElementwiseOperation dxs_in_element_op_; DxsInElementwiseOperation dxs_in_element_op_;
DxsAccElementwiseOperation dxs_out_element_op_; DxsAccElementwiseOperation dxs_out_element_op_;
}; };
...@@ -569,6 +575,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -569,6 +575,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation, DxsAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
...@@ -595,6 +602,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -595,6 +602,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.c1_element_op_,
arg.dxs_in_element_op_, arg.dxs_in_element_op_,
arg.dxs_out_element_op_, arg.dxs_out_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
...@@ -617,6 +625,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -617,6 +625,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation, DxsAccElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
...@@ -643,6 +652,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -643,6 +652,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.c1_element_op_,
arg.dxs_in_element_op_, arg.dxs_in_element_op_,
arg.dxs_out_element_op_, arg.dxs_out_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
...@@ -701,6 +711,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -701,6 +711,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op, DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op) DxsAccElementwiseOperation dxs_out_element_op)
{ {
...@@ -720,6 +731,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -720,6 +731,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
c1_element_op,
dxs_in_element_op, dxs_in_element_op,
dxs_out_element_op}; dxs_out_element_op};
} }
...@@ -743,6 +755,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -743,6 +755,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op, DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op, DxsAccElementwiseOperation dxs_out_element_op,
index_t /* KBatch */ = 1) override index_t /* KBatch */ = 1) override
...@@ -763,6 +776,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle ...@@ -763,6 +776,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
c1_element_op,
dxs_in_element_op, dxs_in_element_op,
dxs_out_element_op); dxs_out_element_op);
} }
......
...@@ -52,6 +52,7 @@ template <typename DPtrsGlobal, ...@@ -52,6 +52,7 @@ template <typename DPtrsGlobal,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation, typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation> typename DxsAccElementwiseOperation>
struct DeviceGemmBiasAddReduce : public BaseOperator struct DeviceGemmBiasAddReduce : public BaseOperator
...@@ -73,6 +74,7 @@ struct DeviceGemmBiasAddReduce : public BaseOperator ...@@ -73,6 +74,7 @@ struct DeviceGemmBiasAddReduce : public BaseOperator
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
C1ElementwiseOperation c1_element_op,
DxsInElementwiseOperation dxs_in_element_op, DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op, DxsAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0; ck::index_t BatchCount = 1) = 0;
...@@ -84,6 +86,7 @@ template <typename DPtrsGlobal, ...@@ -84,6 +86,7 @@ template <typename DPtrsGlobal,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation, typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation> typename DxsAccElementwiseOperation>
using DeviceGemmBiasAddReducePtr = using DeviceGemmBiasAddReducePtr =
...@@ -91,6 +94,7 @@ using DeviceGemmBiasAddReducePtr = ...@@ -91,6 +94,7 @@ using DeviceGemmBiasAddReducePtr =
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
C1ElementwiseOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation>>; DxsAccElementwiseOperation>>;
......
...@@ -22,6 +22,7 @@ template <typename GridwiseGemm, ...@@ -22,6 +22,7 @@ template <typename GridwiseGemm,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsInElementwiseOperation, typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation, typename DxsAccElementwiseOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
...@@ -46,6 +47,7 @@ __global__ void ...@@ -46,6 +47,7 @@ __global__ void
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const C1ElementwiseOperation c1_element_op,
const DxsInElementwiseOperation dxs_in_element_op, const DxsInElementwiseOperation dxs_in_element_op,
const DxsAccElementwiseOperation dxs_out_element_op, const DxsAccElementwiseOperation dxs_out_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
...@@ -72,6 +74,7 @@ __global__ void ...@@ -72,6 +74,7 @@ __global__ void
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
c1_element_op,
dxs_in_element_op, dxs_in_element_op,
dxs_out_element_op, dxs_out_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
...@@ -91,6 +94,7 @@ __global__ void ...@@ -91,6 +94,7 @@ __global__ void
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = c_element_op; ignore = c_element_op;
ignore = c1_element_op;
ignore = dxs_in_element_op; ignore = dxs_in_element_op;
ignore = dxs_out_element_op; ignore = dxs_out_element_op;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
...@@ -114,6 +118,7 @@ template <typename FloatAB, ...@@ -114,6 +118,7 @@ template <typename FloatAB,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename DxsReduceOperation, typename DxsReduceOperation,
typename DxsInElementwiseOperation, typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation, typename DxsAccElementwiseOperation,
...@@ -359,6 +364,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -359,6 +364,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const C1ElementwiseOperation& c1_element_op,
const DxsInElementwiseOperation& dxs_in_element_op, const DxsInElementwiseOperation& dxs_in_element_op,
const DxsAccElementwiseOperation& dxs_out_element_op, const DxsAccElementwiseOperation& dxs_out_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
...@@ -869,6 +875,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -869,6 +875,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
c01_thread_buf); c01_thread_buf);
// c = activation(c + bias)
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) { [&](auto i) {
FloatReduceAcc out; FloatReduceAcc out;
...@@ -883,8 +890,12 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -883,8 +890,12 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
c01_thread_buf); c01_thread_buf);
// c = c + c1_functior(c1)
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) { c_reduce_thread_buf(i) += c01_thread_buf(i); }); [&](auto i) {
c1_element_op(c01_thread_buf(i), c01_thread_buf(i));
c_reduce_thread_buf(i) += c01_thread_buf(i);
});
c_reduce_thread_copy_vgpr_to_global.Run( c_reduce_thread_copy_vgpr_to_global.Run(
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
......
...@@ -30,6 +30,7 @@ using DeviceGemmBiasAddReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmB ...@@ -30,6 +30,7 @@ using DeviceGemmBiasAddReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmB
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
DInElementOps, DInElementOps,
DOutElementOps>; DOutElementOps>;
...@@ -141,6 +142,7 @@ bool profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -141,6 +142,7 @@ bool profile_gemm_bias_add_reduce_impl(int do_verification,
using AElementOp = PassThrough; using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
using C1ElementOp = PassThrough;
using D0ReduceOp = ck::reduce::Add<float>; using D0ReduceOp = ck::reduce::Add<float>;
using D1ReduceOp = ck::reduce::Add<float>; using D1ReduceOp = ck::reduce::Add<float>;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryIdentic<float, float, true>; using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryIdentic<float, float, true>;
...@@ -154,6 +156,7 @@ bool profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -154,6 +156,7 @@ bool profile_gemm_bias_add_reduce_impl(int do_verification,
const auto a_element_op = AElementOp{}; const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{}; const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{}; const auto c_element_op = CElementOp{};
const auto c1_element_op = C1ElementOp{};
const auto d0_reduce_op = D0ReduceOp{}; const auto d0_reduce_op = D0ReduceOp{};
const auto d1_reduce_op = D1ReduceOp{}; const auto d1_reduce_op = D1ReduceOp{};
...@@ -183,8 +186,11 @@ bool profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -183,8 +186,11 @@ bool profile_gemm_bias_add_reduce_impl(int do_verification,
{ {
float acc = float acc =
static_cast<float>(c_m_n_host_result(m, n)) + static_cast<float>(bias_n(n)); static_cast<float>(c_m_n_host_result(m, n)) + static_cast<float>(bias_n(n));
float c1 = c1_m_n(m, n);
c_element_op(acc, acc); c_element_op(acc, acc);
acc += static_cast<float>(c1_m_n(m, n)); c1_element_op(c1, c1);
acc += static_cast<float>(c1);
c_m_n_host_result(m, n) = static_cast<CDataType>(acc); c_m_n_host_result(m, n) = static_cast<CDataType>(acc);
} }
...@@ -299,6 +305,7 @@ bool profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -299,6 +305,7 @@ bool profile_gemm_bias_add_reduce_impl(int do_verification,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
c1_element_op,
dxs_in_element_op, dxs_in_element_op,
dxs_out_element_op); dxs_out_element_op);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment