Commit 66d93ae5 authored by rocking

Rename Reduce -> R

parent 63914743
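
The change is mechanical: reduction-output type aliases and template parameters are shortened from a Reduce prefix to R (e.g. ReduceDataType -> RDataType, ReducePtrsGlobal -> RPtrsGlobal). A minimal before/after sketch, built from aliases that appear in the hunks below:

// before
using ReduceDataType   = F32;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;

// after
using RDataType   = F32;
using RPtrsGlobal = ck::Tuple<RDataType*, RDataType*>;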
......@@ -22,7 +22,7 @@ using BDataType = F16;
using BiasDataType = F32;
using CDataType = F16;
using D0DataType = F16;
using ReduceDataType = F32;
using RDataType = F32;
using GammaDataType = F16;
using BetaDataType = F16;
using LayerNormOutDataType = F16;
......@@ -172,8 +172,8 @@ int main()
const auto normalize_ptrs =
ck::tensor_operation::device::instance::get_device_normalize_from_mean_meansquare_instances<
CDataType,
ReduceDataType,
ReduceDataType,
RDataType,
RDataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType>();
......@@ -203,8 +203,8 @@ int main()
SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{}));
SimpleDeviceMem d0_device_buf(sizeof(D0DataType) *
f_matrix_space_size(M, N, StrideD0, CLayout{}));
SimpleDeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * M);
SimpleDeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) * M);
SimpleDeviceMem reduceMean_device_buf(sizeof(RDataType) * M);
SimpleDeviceMem reduceMeanSquare_device_buf(sizeof(RDataType) * M);
SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N);
SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N);
SimpleDeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * M * N);
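// The two reduction buffers hold one value per output row (hence length M):
// reduceMean_device_buf receives E[c] and reduceMeanSquare_device_buf receives
// E[c^2], which the normalize_from_mean_meansquare instances above consume
// row by row together with gamma and beta.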
......
......@@ -28,13 +28,13 @@ using F64 = double;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using GemmAccDataType = F32;
using ReduceAccDataType = F32;
using ReduceDataType = F64;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*>;
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using GemmAccDataType = F32;
using RAccDataType = F32;
using RDataType = F64;
using RPtrsGlobal = ck::Tuple<RDataType*>;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
......@@ -53,11 +53,11 @@ static constexpr auto GemmSpecialization =
// clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceElementOps, ReduceElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, RAccDataType, RPtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceElementOps, ReduceElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
......@@ -68,12 +68,12 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
BElementOp,
CElementOp>;
template <typename ADataType, typename BDataType, typename CDataType, typename ReduceDataType>
template <typename ADataType, typename BDataType, typename CDataType, typename RDataType>
void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K)
{
std::size_t gemm_flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M;
sizeof(CDataType) * M * N + sizeof(RDataType) * M;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
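// Note on units (a sketch with assumed numbers): gemm_reduce_time is the
// averaged kernel time in milliseconds, so flop / 1.E9 / time_ms gives
// TFLOP/s and bytes / 1.E6 / time_ms gives GB/s. For M = N = K = 1024 and a
// 0.05 ms run, gemm_flop = 2 * 1024^3 ~= 2.15e9, i.e. roughly 43 TFLOP/s.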
......@@ -148,11 +148,11 @@ int main(int argc, char* argv[])
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> reduce_m_host_result(
Tensor<RDataType> reduce_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> reduce_m_device_result(
Tensor<RDataType> reduce_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
......@@ -176,8 +176,7 @@ int main(int argc, char* argv[])
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem reduce_device_buf(sizeof(ReduceDataType) *
reduce_m_device_result.mDesc.GetElementSpace());
DeviceMem reduce_device_buf(sizeof(RDataType) * reduce_m_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
......@@ -220,7 +219,7 @@ int main(int argc, char* argv[])
// [CAUTION]: launch_and_time_kernel will not initialize D.
// If we run the kernel multiple times without re-initializing D, verification will fail
reduce_device_buf.SetValue(ck::NumericLimits<ReduceDataType>::Lowest());
reduce_device_buf.SetValue(ck::NumericLimits<RDataType>::Lowest());
invoker.Run(argument, StreamConfig{nullptr, false});
bool pass = true;
......@@ -242,12 +241,11 @@ int main(int argc, char* argv[])
for(int m = 0; m < M; ++m)
{
ReduceAccDataType reduce_acc = reduce_op.GetIdentityValue<ReduceAccDataType>();
RAccDataType reduce_acc = reduce_op.GetIdentityValue<RAccDataType>();
for(int n = 0; n < N; ++n)
{
ReduceAccDataType curr_val =
ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
RAccDataType curr_val = ck::type_convert<RAccDataType>(c_m_n_host_result(m, n));
reduce_op(reduce_acc, curr_val);
};
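// The host loop mirrors the device kernel: reduce_acc starts from the reduce
// op's identity value (the same NumericLimits<RDataType>::Lowest() used to
// seed reduce_device_buf above for this max-style reduction) and folds in
// every column of the reference GEMM row before the two results are compared.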
......@@ -268,7 +266,7 @@ int main(int argc, char* argv[])
{
float gemm_reduceMax_ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, ReduceDataType>(
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, RDataType>(
gemm_reduceMax_ave_time, M, N, K);
}
......
......@@ -28,13 +28,13 @@ using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using GemmAccDataType = F32;
using ReduceAccDataType = F32;
using ReduceDataType = F32;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using GemmAccDataType = F32;
using RAccDataType = F32;
using RDataType = F32;
using RPtrsGlobal = ck::Tuple<RDataType*, RDataType*>;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
......@@ -62,11 +62,11 @@ static constexpr auto GemmSpecialization =
// clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceDData| A| B| C| Reduce| ReduceInEleOp| ReduceOutEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceDData| A| B| C| Reduce| ReduceInEleOp| ReduceOutEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, RPtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
......@@ -77,13 +77,13 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
BElementOp,
CElementOp>;
template <typename ADataType, typename BDataType, typename CDataType, typename ReduceDataType>
template <typename ADataType, typename BDataType, typename CDataType, typename RDataType>
void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K)
{
std::size_t gemm_flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(ReduceDataType) * M;
sizeof(CDataType) * M * N + sizeof(RDataType) * M +
sizeof(RDataType) * M;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
......@@ -158,15 +158,15 @@ int main(int argc, char* argv[])
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> reduce0_m_host_result(
Tensor<RDataType> reduce0_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<ReduceDataType> reduce1_m_host_result(
Tensor<RDataType> reduce1_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> reduce0_m_device_result(
Tensor<RDataType> reduce0_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<ReduceDataType> reduce1_m_device_result(
Tensor<RDataType> reduce1_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
......@@ -191,9 +191,9 @@ int main(int argc, char* argv[])
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
DeviceMem reduce0_device_buf(sizeof(RDataType) *
reduce0_m_device_result.mDesc.GetElementSpace());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
DeviceMem reduce1_device_buf(sizeof(RDataType) *
reduce1_m_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_m_k.mData.data());
......@@ -269,13 +269,13 @@ int main(int argc, char* argv[])
for(int m = 0; m < M; ++m)
{
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
auto reduce0_acc = reduce0_op.GetIdentityValue<RAccDataType>();
auto reduce1_acc = reduce1_op.GetIdentityValue<RAccDataType>();
for(int n = 0; n < N; ++n)
{
auto c_val = ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
ReduceAccDataType square_c_val;
auto c_val = ck::type_convert<RAccDataType>(c_m_n_host_result(m, n));
RAccDataType square_c_val;
square(square_c_val, c_val);
reduce0_op(reduce0_acc, c_val);
......@@ -284,8 +284,8 @@ int main(int argc, char* argv[])
div(reduce0_acc, reduce0_acc);
div(reduce1_acc, reduce1_acc);
reduce0_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce0_acc);
reduce1_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce1_acc);
reduce0_m_host_result(m) = ck::type_convert<RDataType>(reduce0_acc);
reduce1_m_host_result(m) = ck::type_convert<RDataType>(reduce1_acc);
}
pass = ck::utils::check_err(c_m_n_device_result.mData,
......@@ -307,7 +307,7 @@ int main(int argc, char* argv[])
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, ReduceDataType>(ave_time, M, N, K);
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, RDataType>(ave_time, M, N, K);
}
return pass ? 0 : 1;
......
......@@ -27,12 +27,12 @@ using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using ReduceAccDataType = F32;
using ReduceDataType = F32;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using RAccDataType = F32;
using RDataType = F32;
using RPtrsGlobal = ck::Tuple<RDataType*, RDataType*>;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
......@@ -59,11 +59,11 @@ static constexpr auto GemmSpecialization =
// clang-format off
using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatchedGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| R| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, RPtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
......@@ -143,16 +143,16 @@ int main(int argc, char* argv[])
Tensor<CDataType> c_g_m_n_host_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> d0_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
Tensor<RDataType> d0_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
Tensor<ReduceDataType> d1_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
Tensor<RDataType> d1_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
Tensor<CDataType> c_g_m_n_device_result(
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> d0_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
Tensor<RDataType> d0_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
Tensor<ReduceDataType> d1_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
Tensor<RDataType> d1_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
......@@ -177,10 +177,8 @@ int main(int argc, char* argv[])
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace());
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
d0_g_m_device_result.mDesc.GetElementSpace());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
d1_g_m_device_result.mDesc.GetElementSpace());
DeviceMem reduce0_device_buf(sizeof(RDataType) * d0_g_m_device_result.mDesc.GetElementSpace());
DeviceMem reduce1_device_buf(sizeof(RDataType) * d1_g_m_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
......@@ -269,15 +267,14 @@ int main(int argc, char* argv[])
{
for(int m = 0; m < M; ++m)
{
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
auto reduce0_acc = reduce0_op.GetIdentityValue<RAccDataType>();
auto reduce1_acc = reduce1_op.GetIdentityValue<RAccDataType>();
for(int n = 0; n < N; ++n)
{
auto c_val =
ck::type_convert<ReduceAccDataType>(c_g_m_n_host_result(batch, m, n));
ReduceAccDataType d0_val;
ReduceAccDataType d1_val;
auto c_val = ck::type_convert<RAccDataType>(c_g_m_n_host_result(batch, m, n));
RAccDataType d0_val;
RAccDataType d1_val;
UnaryIdenticElementOp{}(d0_val, c_val);
UnarySquareElementOp{}(d1_val, c_val);
......@@ -285,8 +282,8 @@ int main(int argc, char* argv[])
reduce1_op(reduce1_acc, d1_val);
}
d0_g_m_host_result(batch, m) = ck::type_convert<ReduceDataType>(reduce0_acc);
d1_g_m_host_result(batch, m) = ck::type_convert<ReduceDataType>(reduce1_acc);
d0_g_m_host_result(batch, m) = ck::type_convert<RDataType>(reduce0_acc);
d1_g_m_host_result(batch, m) = ck::type_convert<RDataType>(reduce1_acc);
}
}
......
......@@ -34,9 +34,9 @@ using CDataType = F16;
using BiasDataType = F32;
using D0DataType = F16;
using GemmAccDataType = F32;
using ReduceAccDataType = F32;
using ReduceDataType = F32;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
using RAccDataType = F32;
using RDataType = F32;
using RPtrsGlobal = ck::Tuple<RDataType*, RDataType*>;
using GammaDataType = F16;
using BetaDataType = F16;
using LayerNormOutDataType = F16;
......@@ -69,11 +69,11 @@ static constexpr auto GemmSpecialization =
// clang-format off
using DeviceGemmBiasAddReduceInstance = ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, D0ElementOp, ReduceOps,ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F16, F32, F32, F32, RPtrsGlobal, AElementOp, BElementOp, CElementOp, D0ElementOp, ReduceOps,ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
......@@ -89,8 +89,8 @@ using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta, F:y
using DeviceNormalizeInstance =
ck::tensor_operation::device::Device5AryElementwise<CDataType,
ReduceDataType,
ReduceDataType,
RDataType,
RDataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType,
......@@ -125,7 +125,7 @@ auto f_host_tensor_descriptor2d =
};
template <typename CDataType,
typename ReduceDataType,
typename RDataType,
typename AccDataType,
typename BiasDataType,
typename D0DataType,
......@@ -150,8 +150,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
int StrideC = N;
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<RDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<RDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
auto averageOpInst = UnaryDivElementOp{N};
auto ref_gemm = ReferenceGemmInstance{};
......@@ -196,8 +196,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
averageOpInst(mean_acc, mean_acc);
averageOpInst(square_mean_acc, square_mean_acc);
mean_m(m) = ck::type_convert<ReduceDataType>(mean_acc);
meanSquare_m(m) = ck::type_convert<ReduceDataType>(square_mean_acc);
mean_m(m) = ck::type_convert<RDataType>(mean_acc);
meanSquare_m(m) = ck::type_convert<RDataType>(square_mean_acc);
}
// LayerNorm
......@@ -213,7 +213,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
ck::type_convert<AccDataType>(meanSquare_m(m)),
ck::type_convert<AccDataType>(gamma_n(n)),
ck::type_convert<AccDataType>(beta_n(n)));
out_m_n(m, n) = ck::type_convert<ReduceDataType>(out_acc);
out_m_n(m, n) = ck::type_convert<RDataType>(out_acc);
}
}
}
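// For reference, the Normalize functor applied above turns the two reduced
// rows into a standard layernorm (sketch; epsilon handling assumed):
//   variance = E[x^2] - E[x]^2
//   y        = gamma * (x - E[x]) / sqrt(variance + epsilon) + beta
// which is why the host side only needs mean_m and meanSquare_m per row.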
......@@ -223,7 +223,7 @@ template <typename ADataType,
typename CDataType,
typename BiasDataType,
typename D0DataType,
typename ReduceDataType,
typename RDataType,
typename GammaDataType,
typename BetaDataType,
typename NormalizeDataType>
......@@ -232,11 +232,11 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M,
std::size_t gemm_flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(BiasDataType) * M * N +
sizeof(D0DataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(ReduceDataType) * M;
sizeof(D0DataType) * M * N + sizeof(RDataType) * M +
sizeof(RDataType) * M;
std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(ReduceDataType) * M + sizeof(GammaDataType) * N +
std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(RDataType) * M +
sizeof(RDataType) * M + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
......@@ -267,8 +267,8 @@ int main()
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<BiasDataType> bias_n(f_host_tensor_descriptor1d(N, 1));
Tensor<D0DataType> c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<RDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<RDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
Tensor<LayerNormOutDataType> layerNorm_m_n(
......@@ -286,8 +286,8 @@ int main()
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(D0DataType) * c1_m_n.mDesc.GetElementSpace());
DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * reduceMean_m.mDesc.GetElementSpace());
DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) *
DeviceMem reduceMean_device_buf(sizeof(RDataType) * reduceMean_m.mDesc.GetElementSpace());
DeviceMem reduceMeanSquare_device_buf(sizeof(RDataType) *
reduceMeanSquare_m.mDesc.GetElementSpace());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace());
......@@ -384,19 +384,19 @@ int main()
Tensor<LayerNormOutDataType> host_layerNorm_m_n(
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
host_gemm_layernorm<CDataType, ReduceDataType, ReduceAccDataType>(host_layerNorm_m_n,
a_m_k,
b_k_n,
bias_n,
c1_m_n,
gamma_n,
beta_n,
a_element_op,
b_element_op,
c_element_op,
d_element_op,
M,
N);
host_gemm_layernorm<CDataType, RDataType, RAccDataType>(host_layerNorm_m_n,
a_m_k,
b_k_n,
bias_n,
c1_m_n,
gamma_n,
beta_n,
a_element_op,
b_element_op,
c_element_op,
d_element_op,
M,
N);
layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
pass &= ck::utils::check_err(layerNorm_m_n.mData,
......@@ -421,7 +421,7 @@ int main()
CDataType,
BiasDataType,
D0DataType,
ReduceDataType,
RDataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType>(
......
......@@ -32,9 +32,9 @@ using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using GemmAccDataType = F32;
using ReduceAccDataType = F32;
using ReduceDataType = F32;
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
using RAccDataType = F32;
using RDataType = F32;
using RPtrsGlobal = ck::Tuple<RDataType*, RDataType*>;
using GammaDataType = F16;
using BetaDataType = F16;
using LayerNormOutDataType = F16;
......@@ -65,11 +65,11 @@ static constexpr auto GemmSpecialization =
// clang-format off
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps,ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, RPtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps,ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
......@@ -85,8 +85,8 @@ using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta, F:y
using DeviceNormalizeInstance =
ck::tensor_operation::device::Device5AryElementwise<CDataType,
ReduceDataType,
ReduceDataType,
RDataType,
RDataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType,
......@@ -121,7 +121,7 @@ auto f_host_tensor_descriptor2d =
};
template <typename CDataType,
typename ReduceDataType,
typename RDataType,
typename A_functor,
typename B_functor,
typename C_functor>
......@@ -140,8 +140,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
int StrideC = N;
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<RDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<RDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
auto averageOpInst = UnaryDivElementOp{N};
auto ref_gemm = ReferenceGemmInstance{};
......@@ -156,13 +156,13 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
auto reduceSumOpInst = ReduceSumOp{};
for(int m = 0; m < M; ++m)
{
auto mean_acc = reduceSumOpInst.GetIdentityValue<ReduceAccDataType>();
auto square_mean_acc = reduceSumOpInst.GetIdentityValue<ReduceAccDataType>();
auto mean_acc = reduceSumOpInst.GetIdentityValue<RAccDataType>();
auto square_mean_acc = reduceSumOpInst.GetIdentityValue<RAccDataType>();
for(int n = 0; n < N; ++n)
{
auto c_val = ck::type_convert<ReduceAccDataType>(c_m_n(m, n));
auto square_c_val = reduceSumOpInst.GetIdentityValue<ReduceAccDataType>();
auto c_val = ck::type_convert<RAccDataType>(c_m_n(m, n));
auto square_c_val = reduceSumOpInst.GetIdentityValue<RAccDataType>();
UnarySquareElementOp{}(square_c_val, c_val);
......@@ -172,8 +172,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
averageOpInst(mean_acc, mean_acc);
averageOpInst(square_mean_acc, square_mean_acc);
mean_m(m) = ck::type_convert<ReduceDataType>(mean_acc);
meanSquare_m(m) = ck::type_convert<ReduceDataType>(square_mean_acc);
mean_m(m) = ck::type_convert<RDataType>(mean_acc);
meanSquare_m(m) = ck::type_convert<RDataType>(square_mean_acc);
}
// LayerNorm
......@@ -197,7 +197,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ReduceDataType,
typename RDataType,
typename GammaDataType,
typename BetaDataType,
typename NormalizeDataType>
......@@ -205,11 +205,11 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M,
{
std::size_t gemm_flop = std::size_t(2) * M * N * K;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(ReduceDataType) * M;
sizeof(CDataType) * M * N + sizeof(RDataType) * M +
sizeof(RDataType) * M;
std::size_t normalize_num_btye = sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(ReduceDataType) * M + sizeof(GammaDataType) * N +
std::size_t normalize_num_btye = sizeof(CDataType) * M * N + sizeof(RDataType) * M +
sizeof(RDataType) * M + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
......@@ -237,8 +237,8 @@ int main()
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<ReduceDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<ReduceDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<RDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<RDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
Tensor<LayerNormOutDataType> layerNorm_m_n(
......@@ -252,8 +252,8 @@ int main()
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * reduceMean_m.mDesc.GetElementSpace());
DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) *
DeviceMem reduceMean_device_buf(sizeof(RDataType) * reduceMean_m.mDesc.GetElementSpace());
DeviceMem reduceMeanSquare_device_buf(sizeof(RDataType) *
reduceMeanSquare_m.mDesc.GetElementSpace());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace());
......@@ -347,16 +347,16 @@ int main()
Tensor<LayerNormOutDataType> host_layerNorm_m_n(
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
host_gemm_layernorm<CDataType, ReduceDataType>(host_layerNorm_m_n,
a_m_k,
b_k_n,
gamma_n,
beta_n,
a_element_op,
b_element_op,
c_element_op,
M,
N);
host_gemm_layernorm<CDataType, RDataType>(host_layerNorm_m_n,
a_m_k,
b_k_n,
gamma_n,
beta_n,
a_element_op,
b_element_op,
c_element_op,
M,
N);
layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
pass &= ck::utils::check_err(layerNorm_m_n.mData,
......@@ -379,7 +379,7 @@ int main()
DumpGemmLayerNormPerf<ADataType,
BDataType,
CDataType,
ReduceDataType,
RDataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType>(
......
......@@ -33,16 +33,16 @@ template <typename ALayout,
typename D0DataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename ReduceAccDataType,
typename ReducePtrsGlobal,
typename RAccDataType,
typename RPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename D0ElementwiseOperation,
typename ReduceOperations,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
typename ReduceGlobalMemoryDataOperation,
typename RInElementwiseOperations,
typename RAccElementwiseOperations,
typename RGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
......@@ -390,17 +390,17 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
CDataType,
BiasDataType,
D0DataType,
ReduceAccDataType,
ReducePtrsGlobal,
RAccDataType,
RPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ElementwiseOperation,
ReduceOperations,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
RInElementwiseOperations,
RAccElementwiseOperations,
InMemoryDataOperationEnum::Set,
ReduceGlobalMemoryDataOperation,
RGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
......@@ -451,7 +451,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
CDataType* p_c_grid,
const BiasDataType* p_bias_grid,
const D0DataType* p_d0_grid,
ReducePtrsGlobal p_reduces_grid,
RPtrsGlobal p_rs_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
......@@ -463,14 +463,14 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
D0ElementwiseOperation d0_element_op,
ReduceInElementwiseOperations reduce_in_element_ops,
ReduceAccElementwiseOperations reduce_out_element_ops)
RInElementwiseOperations reduce_in_element_ops,
RAccElementwiseOperations reduce_out_element_ops)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_bias_grid_{p_bias_grid},
p_d0_grid_{p_d0_grid},
p_reduces_grid_{p_reduces_grid},
p_rs_grid_{p_rs_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
......@@ -507,7 +507,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
c1_grid_desc_m_n_);
reduce_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
GridwiseGemm::MakeRGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
}
}
......@@ -517,7 +517,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
CDataType* p_c_grid_;
const BiasDataType* p_bias_grid_;
const D0DataType* p_d0_grid_;
ReducePtrsGlobal p_reduces_grid_;
RPtrsGlobal p_rs_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
......@@ -530,15 +530,14 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
c0_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
reduce_grid_desc_mblock_mperblock_;
typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
D0ElementwiseOperation d0_element_op_;
ReduceInElementwiseOperations reduce_in_element_ops_;
ReduceAccElementwiseOperations reduce_out_element_ops_;
RInElementwiseOperations reduce_in_element_ops_;
RAccElementwiseOperations reduce_out_element_ops_;
};
// Invoker
......@@ -571,19 +570,19 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
CDataType,
BiasDataType,
D0DataType,
ReducePtrsGlobal,
RPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ElementwiseOperation,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
RInElementwiseOperations,
RAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;
......@@ -598,7 +597,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
arg.p_c_grid_,
arg.p_bias_grid_,
arg.p_d0_grid_,
arg.p_reduces_grid_,
arg.p_rs_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -621,19 +620,19 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
CDataType,
BiasDataType,
D0DataType,
ReducePtrsGlobal,
RPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
D0ElementwiseOperation,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
RInElementwiseOperations,
RAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
......@@ -648,7 +647,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
arg.p_c_grid_,
arg.p_bias_grid_,
arg.p_d0_grid_,
arg.p_reduces_grid_,
arg.p_rs_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
......@@ -701,7 +700,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
const void* p_bias,
std::array<const void*, 1> p_ds,
void* p_c,
std::array<void*, NumReduce> p_reduces,
std::array<void*, NumReduce> p_rs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
......@@ -714,24 +713,24 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op)
{
ReducePtrsGlobal reduce_tuple = generate_tuple(
RPtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
auto tmp = RPtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
return static_cast<T*>(p_rs[I]);
},
Number<NumReduce>{});
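// generate_tuple walks I = 0..NumReduce-1 at compile time, looks up the
// pointer type stored in slot I of RPtrsGlobal, and casts the type-erased
// void* from p_rs back to it; the same pattern below rebuilds the typed
// elementwise-operation tuples from the interface's void* arrays.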
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
RInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
auto tmp = RInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
RAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
auto tmp = RAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
......@@ -776,7 +775,7 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
const void* p_bias,
std::array<const void*, 1> p_ds,
void* p_c,
std::array<void*, NumReduce> p_reduces,
std::array<void*, NumReduce> p_rs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
......@@ -790,24 +789,24 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
std::array<void*, NumReduce> reduce_out_element_op,
index_t /* KBatch */ = 1) override
{
ReducePtrsGlobal reduce_tuple = generate_tuple(
RPtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
auto tmp = RPtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
return static_cast<T*>(p_rs[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
RInElementwiseOperations reduce_in_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
auto tmp = RInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
RAccElementwiseOperations reduce_out_element_ops = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
auto tmp = RAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
},
......
......@@ -21,7 +21,7 @@ struct DeviceGemmReduce : public BaseOperator
const void* p_bias,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
std::array<void*, NumReduce> p_reduces,
std::array<void*, NumReduce> p_rs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
......@@ -31,8 +31,8 @@ struct DeviceGemmReduce : public BaseOperator
std::array<ck::index_t, NumDTensor> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, NumDTensor> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_ops,
std::array<void*, NumReduce> reduce_out_element_ops,
std::array<void*, NumReduce> r_in_element_ops,
std::array<void*, NumReduce> r_out_element_ops,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
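// Reduction outputs and their elementwise operations cross this interface as
// type-erased std::array<void*, NumReduce>; concrete device ops (e.g. the
// DeviceGemmBiasAddReduce_Xdl_CShuffle above) rebuild the typed RPtrsGlobal
// and elementwise-op tuples with generate_tuple inside MakeArgumentPointer.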
......
......@@ -31,15 +31,15 @@ template <typename ALayout,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename ReduceAccDataType,
typename ReducePtrsGlobal,
typename RAccDataType,
typename RPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ReduceOperations,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
typename ReduceGlobalMemoryDataOperation,
typename RInElementwiseOperations,
typename RAccElementwiseOperations,
typename RGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
......@@ -347,9 +347,9 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
}
// assume Reduce is a packed tensor
static auto MakeReduceGridDescriptor_M(index_t MRaw)
static auto MakeRGridDescriptor_M(index_t MRaw)
{
const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
const auto r_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw;
......@@ -360,7 +360,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(d_grid_desc_mraw,
return transform_tensor_descriptor(r_grid_desc_mraw,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
......@@ -368,14 +368,14 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
else
{
// not pad M
return d_grid_desc_mraw;
return r_grid_desc_mraw;
}
}
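The rounding above is the whole trick behind the padded descriptor: a minimal sketch of the arithmetic, assuming an illustrative MPerBlock of 64, is shown below. A raw reduction length of 100 rows is rounded up to 128, and the right pad of 28 rows is what the padding specializations such as MNKPadding add to the descriptor.

#include <cstdio>

// Same arithmetic as MakeRGridDescriptor_M: round MRaw up to a multiple of
// MPerBlock and record how much right padding that introduces.
constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    constexpr int MPerBlock = 64; // illustrative tile size
    const int MRaw = 100;

    const int M    = integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; // 128
    const int MPad = M - MRaw;                                         // 28

    std::printf("MRaw=%d M=%d MPad=%d\n", MRaw, M, MPad);
}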
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1));
using RGridDesc_M = decltype(MakeRGridDescriptor_M(1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
......@@ -383,20 +383,20 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
GemmAccDataType,
CShuffleDataType,
CDataType,
ReduceAccDataType,
ReducePtrsGlobal,
RAccDataType,
RPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ReduceOperations,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
RInElementwiseOperations,
RAccElementwiseOperations,
InMemoryDataOperationEnum::Set,
ReduceGlobalMemoryDataOperation,
RGlobalMemoryDataOperation,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
ReduceGridDesc_M,
RGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
......@@ -439,7 +439,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
ReducePtrsGlobal p_reduces_grid,
RPtrsGlobal p_rs_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
......@@ -449,24 +449,24 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ReduceInElementwiseOperations reduce_in_element_ops,
ReduceAccElementwiseOperations reduce_out_element_ops)
RInElementwiseOperations r_in_element_ops,
RAccElementwiseOperations r_out_element_ops)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_reduces_grid_{p_reduces_grid},
p_rs_grid_{p_rs_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
r_grid_desc_m_{DeviceOp::MakeRGridDescriptor_M(MRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
reduce_grid_desc_mblock_mperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op},
reduce_in_element_ops_{reduce_in_element_ops},
reduce_out_element_ops_{reduce_out_element_ops}
r_in_element_ops_{r_in_element_ops},
r_out_element_ops_{r_out_element_ops}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
......@@ -478,7 +478,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
c_grid_desc_m_n_);
reduce_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
GridwiseGemm::MakeRGridDescriptor_MBlock_MPerBlock(r_grid_desc_m_);
}
}
......@@ -486,21 +486,20 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
ReducePtrsGlobal p_reduces_grid_;
RPtrsGlobal p_rs_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
ReduceGridDesc_M reduce_grid_desc_m_;
RGridDesc_M r_grid_desc_m_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
reduce_grid_desc_mblock_mperblock_;
typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
ReduceInElementwiseOperations reduce_in_element_ops_;
ReduceAccElementwiseOperations reduce_out_element_ops_;
RInElementwiseOperations r_in_element_ops_;
RAccElementwiseOperations r_out_element_ops_;
};
// Invoker
......@@ -525,7 +524,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) << "}"
std::cout << "arg.r_grid_desc_m_{ " << arg.r_grid_desc_m_.GetLength(I0) << "}"
<< std::endl;
}
#endif
......@@ -551,16 +550,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
ReducePtrsGlobal,
RPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
RInElementwiseOperations,
RAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;
......@@ -573,12 +572,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_reduces_grid_,
arg.p_rs_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.reduce_in_element_ops_,
arg.reduce_out_element_ops_,
arg.r_in_element_ops_,
arg.r_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
......@@ -591,16 +590,16 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
ReducePtrsGlobal,
RPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ReduceInElementwiseOperations,
ReduceAccElementwiseOperations,
RInElementwiseOperations,
RAccElementwiseOperations,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
......@@ -613,12 +612,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_reduces_grid_,
arg.p_rs_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.reduce_in_element_ops_,
arg.reduce_out_element_ops_,
arg.r_in_element_ops_,
arg.r_out_element_ops_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
......@@ -663,7 +662,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
const void* p_bias,
std::array<const void*, 0> p_ds,
void* p_c,
std::array<void*, NumReduce> p_reduces,
std::array<void*, NumReduce> p_rs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
......@@ -673,34 +672,34 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
std::array<ck::index_t, 0> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, 0> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op)
std::array<void*, NumReduce> r_in_element_ops,
std::array<void*, NumReduce> r_out_element_ops)
{
(void)p_bias;
(void)p_ds;
(void)StrideDs;
(void)d_element_ops;
ReducePtrsGlobal reduce_tuple = generate_tuple(
RPtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
auto tmp = RPtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
return static_cast<T*>(p_rs[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
RInElementwiseOperations r_in_element_op_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
auto tmp = RInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
return *(static_cast<T*>(r_in_element_ops[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
RAccElementwiseOperations r_out_element_op_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
auto tmp = RAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
return *(static_cast<T*>(r_out_element_ops[I]));
},
Number<NumReduce>{});
......@@ -724,8 +723,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
a_element_op,
b_element_op,
c_element_op,
reduce_in_element_ops,
reduce_out_element_ops};
r_in_element_op_tuple,
r_out_element_op_tuple};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -737,7 +736,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
const void* p_bias,
std::array<const void*, 0> p_ds,
void* p_c,
std::array<void*, NumReduce> p_reduces,
std::array<void*, NumReduce> p_rs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
......@@ -747,8 +746,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
std::array<ck::index_t, 0> StrideDs,
std::array<void*, 3> gemm_element_ops,
std::array<void*, 0> d_element_ops,
std::array<void*, NumReduce> reduce_in_element_op,
std::array<void*, NumReduce> reduce_out_element_op,
std::array<void*, NumReduce> r_in_element_ops,
std::array<void*, NumReduce> r_out_element_ops,
ck::index_t = 1) override
{
(void)p_bias;
......@@ -756,26 +755,26 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
(void)StrideDs;
(void)d_element_ops;
ReducePtrsGlobal reduce_tuple = generate_tuple(
RPtrsGlobal reduce_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReducePtrsGlobal{}[I];
auto tmp = RPtrsGlobal{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return static_cast<T*>(p_reduces[I]);
return static_cast<T*>(p_rs[I]);
},
Number<NumReduce>{});
ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
RInElementwiseOperations r_in_element_op_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReduceInElementwiseOperations{}[I];
auto tmp = RInElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_in_element_op[I]));
return *(static_cast<T*>(r_in_element_ops[I]));
},
Number<NumReduce>{});
ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
RAccElementwiseOperations r_out_element_op_tuple = generate_tuple(
[&](auto I) {
auto tmp = ReduceAccElementwiseOperations{}[I];
auto tmp = RAccElementwiseOperations{}[I];
using T = remove_pointer_t<decltype(tmp)>;
return *(static_cast<T*>(reduce_out_element_op[I]));
return *(static_cast<T*>(r_out_element_ops[I]));
},
Number<NumReduce>{});
......@@ -799,8 +798,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
a_element_op,
b_element_op,
c_element_op,
reduce_in_element_ops,
reduce_out_element_ops);
r_in_element_op_tuple,
r_out_element_op_tuple);
}
// polymorphic
......
......@@ -38,7 +38,7 @@ template <typename ALayout,
typename C0DataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename ReduceAccDataType,
typename RAccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
......@@ -385,7 +385,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
CShuffleDataType,
CDataType,
C0DataType,
ReduceAccDataType,
RAccDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
......
......@@ -23,19 +23,19 @@ template <typename GridwiseGemm,
typename FloatC,
typename FloatC0,
typename FloatC1,
typename ReducePtrsGlobal,
typename RPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
typename RInElementwiseOperations,
typename RAccElementwiseOperations,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename ReduceGridDescriptor_MBlock_MPerBlock,
typename RGridDescriptor_MBlock_MPerBlock,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
......@@ -48,13 +48,13 @@ __global__ void
FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_bias_grid,
const FloatC1* __restrict__ p_d0_grid,
ReducePtrsGlobal p_reduces_grid,
RPtrsGlobal p_reduces_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const C1ElementwiseOperation c1_element_op,
const ReduceInElementwiseOperations reduce_in_element_ops,
const ReduceAccElementwiseOperations reduce_out_element_ops,
const RInElementwiseOperations reduce_in_element_ops,
const RAccElementwiseOperations reduce_out_element_ops,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -63,7 +63,7 @@ __global__ void
c0_grid_desc_mblock_mperblock_nblock_nperblock,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock,
const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
const RGridDescriptor_MBlock_MPerBlock r_grid_desc_mblock_mperblock,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
......@@ -87,7 +87,7 @@ __global__ void
c_grid_desc_mblock_mperblock_nblock_nperblock,
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c1_grid_desc_mblock_mperblock_nblock_nperblock,
reduce_grid_desc_mblock_mperblock,
r_grid_desc_mblock_mperblock,
block_2_ctile_map);
#else
ignore = p_a_grid;
......@@ -107,7 +107,7 @@ __global__ void
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = c0_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = reduce_grid_desc_mblock_mperblock;
ignore = r_grid_desc_mblock_mperblock;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......@@ -119,16 +119,16 @@ template <typename FloatAB,
typename FloatC0,
typename FloatC1,
typename FloatReduceAcc,
typename ReducePtrsGlobal,
typename RPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename C1ElementwiseOperation,
typename ReduceOperations,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
typename RInElementwiseOperations,
typename RAccElementwiseOperations,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename ReduceGlobalMemoryDataOperation,
typename RGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
......@@ -321,18 +321,18 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
__host__ __device__ static constexpr auto
MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m)
MakeRGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m)
{
const auto M = d_grid_desc_m.GetLength(I0);
const auto MBlock = M / MPerBlock;
const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor(
const auto r_grid_desc_mblock_mperblock = transform_tensor_descriptor(
d_grid_desc_m,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
return reduce_grid_desc_mblock_mperblock;
return r_grid_desc_mblock_mperblock;
}
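To make the (MBlock, MPerBlock) view above concrete, the sketch below spells out the index mapping that the unmerge transform encodes, assuming an illustrative MPerBlock of 64: a flat row index m in the length-M reduction output corresponds to coordinates (mblock, mperblock) with m = mblock * MPerBlock + mperblock, which is how the per-block epilogue addresses its slice of the reduction buffer.

#include <cstdio>

int main()
{
    constexpr int MPerBlock = 64; // illustrative tile size; M is assumed padded to a multiple of it
    const int M      = 256;
    const int MBlock = M / MPerBlock;

    std::printf("M=%d MBlock=%d MPerBlock=%d\n", M, MBlock, MPerBlock);

    // Forward mapping used by the (MBlock, MPerBlock) descriptor.
    for(int m = 0; m < M; m += 100)
    {
        const int mblock    = m / MPerBlock;
        const int mperblock = m % MPerBlock;
        std::printf("m=%3d -> (mblock=%d, mperblock=%2d), recombined=%3d\n",
                    m, mblock, mperblock, mblock * MPerBlock + mperblock);
    }
}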
// return block_id to C matrix tile idx (m0, n0) mapping
......@@ -352,37 +352,36 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))>;
using ReduceGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;
using RGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeRGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
template <bool HasMainKBlockLoop, typename Block2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_bias_grid,
const FloatC1* __restrict__ p_d0_grid,
ReducePtrsGlobal p_reduces_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const C1ElementwiseOperation& c1_element_op,
const ReduceInElementwiseOperations& reduce_in_element_ops,
const ReduceAccElementwiseOperations& reduce_out_element_ops,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c0_grid_desc_mblock_mperblock_nblock_nperblock,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c1_grid_desc_mblock_mperblock_nblock_nperblock,
const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map)
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_bias_grid,
const FloatC1* __restrict__ p_d0_grid,
RPtrsGlobal p_reduces_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const C1ElementwiseOperation& c1_element_op,
const RInElementwiseOperations& reduce_in_element_ops,
const RAccElementwiseOperations& reduce_out_element_ops,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c0_grid_desc_mblock_mperblock_nblock_nperblock,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c1_grid_desc_mblock_mperblock_nblock_nperblock,
const RGridDescriptor_MBlock_MPerBlock& r_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
......@@ -769,15 +768,15 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
FloatReduceAcc,
remove_pointer_t<decltype(p_reduce_grid)>,
decltype(reduce_thread_desc_mblock_mperblock),
decltype(reduce_grid_desc_mblock_mperblock),
decltype(r_grid_desc_mblock_mperblock),
decltype(reduce_acc_element_op),
Sequence<1, mreduce_per_thread>,
Sequence<0, 1>,
1,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
ReduceGlobalMemoryDataOperation::At(I),
RGlobalMemoryDataOperation::At(I),
1,
false>{reduce_grid_desc_mblock_mperblock,
false>{r_grid_desc_mblock_mperblock,
make_multi_index(block_work_idx[I0], // mblock
c_reduce_thread_data_idx_begin[I0]), // mperblock
reduce_acc_element_op};
......@@ -914,7 +913,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto& p_reduce_grid = p_reduces_grid[In];
auto reduce_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());
p_reduce_grid, r_grid_desc_mblock_mperblock.GetElementSpaceSize());
auto reduce_thread_buf =
make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
......@@ -958,14 +957,14 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),
reduce_thread_buf,
reduce_grid_desc_mblock_mperblock,
r_grid_desc_mblock_mperblock,
reduce_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
reduce_grid_desc_mblock_mperblock,
r_grid_desc_mblock_mperblock,
make_tuple(c_global_step[I0], c_global_step[I1]));
}
});
......
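Before moving on to the plain GEMM+reduce kernel, a minimal host-side sketch of the partial reduction this epilogue performs per thread may help: each thread holds a small m-by-n register tile, seeds one accumulator per row with the reduction's identity value, applies the input elementwise operation to every element, and folds each row into its accumulator before the result is written to the length-M grid buffer. The tile shape, the Add reduction and the squaring input op below are illustrative assumptions only.

#include <array>
#include <cstdio>

int main()
{
    constexpr int mreduce_per_thread = 4; // illustrative per-thread tile shape
    constexpr int nreduce_per_thread = 8;

    // Stand-in for c_reduce_thread_buf: the per-thread slice of the C tile.
    std::array<float, mreduce_per_thread * nreduce_per_thread> c_tile{};
    for(int i = 0; i < mreduce_per_thread * nreduce_per_thread; ++i)
        c_tile[i] = static_cast<float>(i % 5);

    auto reduce_in_element_op = [](float x) { return x * x; }; // e.g. square before summing
    const float identity      = 0.0f;                          // identity value of an Add reduction

    // One accumulator per row, as in reduce_thread_buf.
    std::array<float, mreduce_per_thread> acc{};
    for(int m = 0; m < mreduce_per_thread; ++m)
    {
        acc[m] = identity;
        for(int n = 0; n < nreduce_per_thread; ++n)
            acc[m] += reduce_in_element_op(c_tile[m * nreduce_per_thread + n]);
        std::printf("row %d partial reduction = %f\n", m, acc[m]);
    }
}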
......@@ -21,16 +21,16 @@ namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename ReducePtrsGlobal,
typename RPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
typename RInElementwiseOperations,
typename RAccElementwiseOperations,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename ReduceGridDescriptor_MBlock_MPerBlock,
typename RGridDescriptor_MBlock_MPerBlock,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
......@@ -41,17 +41,17 @@ __global__ void
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
ReducePtrsGlobal p_reduces_grid,
RPtrsGlobal p_rs_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const ReduceInElementwiseOperations reduce_in_element_ops,
const ReduceAccElementwiseOperations reduce_out_element_ops,
const RInElementwiseOperations r_in_element_ops,
const RAccElementwiseOperations r_out_element_ops,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
const RGridDescriptor_MBlock_MPerBlock r_grid_desc_mblock_mperblock,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
......@@ -60,32 +60,32 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_reduces_grid,
p_rs_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
reduce_in_element_ops,
reduce_out_element_ops,
r_in_element_ops,
r_out_element_ops,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
reduce_grid_desc_mblock_mperblock,
r_grid_desc_mblock_mperblock,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = p_reduces_grid;
ignore = p_rs_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = reduce_in_element_ops;
ignore = reduce_out_element_ops;
ignore = r_in_element_ops;
ignore = r_out_element_ops;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = reduce_grid_desc_mblock_mperblock;
ignore = r_grid_desc_mblock_mperblock;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......@@ -94,20 +94,20 @@ template <typename FloatAB,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatC,
typename FloatReduceAcc,
typename ReducePtrsGlobal,
typename FloatRAcc,
typename RPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ReduceOperations,
typename ReduceInElementwiseOperations,
typename ReduceAccElementwiseOperations,
typename RInElementwiseOperations,
typename RAccElementwiseOperations,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename ReduceGlobalMemoryDataOperation,
typename RGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
typename ReduceGridDesc_M,
typename RGridDesc_M,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
......@@ -293,18 +293,18 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
__host__ __device__ static constexpr auto
MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m)
MakeRGridDescriptor_MBlock_MPerBlock(const RGridDesc_M& d_grid_desc_m)
{
const auto M = d_grid_desc_m.GetLength(I0);
const auto MBlock = M / MPerBlock;
const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor(
const auto r_grid_desc_mblock_mperblock = transform_tensor_descriptor(
d_grid_desc_m,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{}))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
return reduce_grid_desc_mblock_mperblock;
return r_grid_desc_mblock_mperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
......@@ -318,30 +318,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using ReduceGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;
using RGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeRGridDescriptor_MBlock_MPerBlock(RGridDesc_M{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
template <bool HasMainKBlockLoop, typename Block2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
ReducePtrsGlobal p_reduces_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const ReduceInElementwiseOperations& reduce_in_element_ops,
const ReduceAccElementwiseOperations& reduce_out_element_ops,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map)
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
RPtrsGlobal p_rs_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const RInElementwiseOperations& r_in_element_ops,
const RAccElementwiseOperations& r_out_element_ops,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const RGridDescriptor_MBlock_MPerBlock& r_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
......@@ -715,7 +714,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto reduce_thread_desc_mblock_mperblock =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatRAcc>(
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
// reduce: threadwise copy from LDS to VGPR
......@@ -731,7 +730,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatCShuffle,
FloatReduceAcc,
FloatRAcc,
decltype(c_reduce_block_desc_mperblock_nperblock),
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(c_reduce_thread_lengths_mperblock_nperblock),
......@@ -743,27 +742,27 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple(
[&](auto I) {
auto p_reduce_grid = p_reduces_grid[I];
auto reduce_acc_element_op = reduce_out_element_ops[I];
auto p_reduce_grid = p_rs_grid[I];
auto reduce_acc_element_op = r_out_element_ops[I];
return ThreadwiseTensorSliceTransfer_v1r3<
FloatReduceAcc,
FloatRAcc,
remove_pointer_t<decltype(p_reduce_grid)>,
decltype(reduce_thread_desc_mblock_mperblock),
decltype(reduce_grid_desc_mblock_mperblock),
decltype(r_grid_desc_mblock_mperblock),
decltype(reduce_acc_element_op),
Sequence<1, mreduce_per_thread>,
Sequence<0, 1>,
1,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
ReduceGlobalMemoryDataOperation::At(I),
RGlobalMemoryDataOperation::At(I),
1,
false>{reduce_grid_desc_mblock_mperblock,
false>{r_grid_desc_mblock_mperblock,
make_multi_index(block_work_idx[I0], // mblock
c_reduce_thread_data_idx_begin[I0]), // mperblock
reduce_acc_element_op};
},
Number<p_reduces_grid.Size()>{});
Number<p_rs_grid.Size()>{});
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
......@@ -798,24 +797,24 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_tuple(I0, I0),
c_reduce_thread_buf);
static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
auto& p_reduce_grid = p_reduces_grid[In];
static_for<0, p_rs_grid.Size(), 1>{}([&](auto In) {
auto& p_reduce_grid = p_rs_grid[In];
auto reduce_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());
p_reduce_grid, r_grid_desc_mblock_mperblock.GetElementSpaceSize());
auto reduce_thread_buf =
make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
make_static_buffer<AddressSpaceEnum::Vgpr, FloatRAcc>(
reduce_thread_desc_mperblock.GetElementSpaceSize());
auto& reduce_in_element_op = reduce_in_element_ops[In];
auto& reduce_in_element_op = r_in_element_ops[In];
auto& reduce_thread_copy_vgpr_to_global =
reduce_tuple_thread_copy_vgpr_to_global(In);
using ReduceOperation = remove_cvref_t<decltype(ReduceOperations{}[In])>;
using ThreadwiseReduce =
ThreadwiseReduction<FloatReduceAcc,
ThreadwiseReduction<FloatRAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(reduce_thread_desc_mperblock),
ReduceOperation,
......@@ -823,7 +822,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// Global write Gemm shuffle + reduction
const auto reduce_identityVal =
ReduceOperation::template GetIdentityValue<FloatReduceAcc>();
ReduceOperation::template GetIdentityValue<FloatRAcc>();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });
......@@ -846,14 +845,14 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
make_tuple(I0, I0),
reduce_thread_buf,
reduce_grid_desc_mblock_mperblock,
r_grid_desc_mblock_mperblock,
reduce_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
reduce_grid_desc_mblock_mperblock,
r_grid_desc_mblock_mperblock,
make_tuple(c_global_step[I0], c_global_step[I1]));
}
});
......
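Finally, a minimal host-side sketch of what the renamed p_rs / r_in_element_ops / r_out_element_ops parameters expect from a caller, under illustrative assumptions (two reductions, host std::vector buffers standing in for device allocations, and a hypothetical Identity functor standing in for the real elementwise operations): each reduction produces one value per row, so every output buffer has length M, and the type-erased arrays simply collect raw pointers that the device op later casts back to the element types of RPtrsGlobal and the concrete functor types of the elementwise-op tuples.

#include <array>
#include <vector>

// Illustrative stand-in only; the real elementwise operations come from the library.
struct Identity
{
    float operator()(float x) const { return x; }
};

int main()
{
    constexpr int NumReduce = 2;
    const int M = 1024;

    // One length-M output buffer per reduction (host memory as a placeholder
    // for device allocations).
    std::vector<double> reduce0(M), reduce1(M);
    std::array<void*, NumReduce> p_rs{reduce0.data(), reduce1.data()};

    // Type-erased elementwise ops; the device op dereferences these as the
    // concrete types held by RInElementwiseOperations / RAccElementwiseOperations.
    Identity in_op0, in_op1, out_op0, out_op1;
    std::array<void*, NumReduce> r_in_element_ops{&in_op0, &in_op1};
    std::array<void*, NumReduce> r_out_element_ops{&out_op0, &out_op1};

    // These arrays, together with the GEMM operands, are what MakeArgumentPointer consumes.
    (void)p_rs;
    (void)r_in_element_ops;
    (void)r_out_element_ops;
    return 0;
}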