Commit 300ac4e4 authored by rocking's avatar rocking
Browse files

Implement gemm bias add reduction

parent 09a2b547
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp)
add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp) add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp)
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_5ary_elementwise.hpp"
#include "device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "element_wise_reduce_operation.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using ADataType = F16;
using BDataType = F16;
using CDataType = F16;
using C0DataType = F32;
using C1DataType = F32;
using GemmAccDataType = F32;
using ReduceAccDataType = F32;
using DDataType = F32;
using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
using GammaDataType = F16;
using BetaDataType = F16;
using LayerNormOutDataType = F16;
using NormalizeComputeDataType = F32;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = ck::tensor_operation::element_wise::Relu;
using ReduceSumOp = ck::reduce::Add<ReduceAccDataType>;
using DxsReduceOp = ck::Tuple<ReduceSumOp, ReduceSumOp>;
using UnaryIdenticElementOp =
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, false>;
using UnaryDivElementOp =
ck::tensor_operation::element_wise::UnaryIdentic<ReduceAccDataType, ReduceAccDataType, true>;
using UnarySquareElementOp =
ck::tensor_operation::element_wise::UnarySquare<ReduceAccDataType, ReduceAccDataType, false>;
using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOp = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using DxsGlobalMemOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>;
static constexpr auto GemmSpecialization =
ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmBiasAddReduceInstance = ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
GemmAccDataType,
AElementOp,
BElementOp,
PassThrough>;
using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y
using DeviceNormalizeInstance =
ck::tensor_operation::device::Device5AryElementwise<CDataType,
DDataType,
DDataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType,
NormalizeComputeDataType,
NormalizeFunctor,
2,
8,
8, // scalarPerVector: gemm_out
1, // scalarPerVector: reduce_mean
1, // scalarPerVector: reduce_mean_square
8, // scalarPerVector: Gamma
8, // scalarPerVector: Beta
8>; // scalarPerVector: LayerNorm_out
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}),
std::vector<std::size_t>({stride}));
};
auto f_host_tensor_descriptor2d =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
template <typename CDataType,
typename DDataType,
typename AccDataType,
typename C0DataType,
typename C1DataType,
typename A_functor,
typename B_functor,
typename C_functor>
void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
const Tensor<ADataType>& a_m_k,
const Tensor<ADataType>& b_k_n,
const Tensor<C0DataType>& bias_n,
const Tensor<C1DataType>& c1_m_n,
const Tensor<GammaDataType>& gamma_n,
const Tensor<GammaDataType>& beta_n,
A_functor a_element_op,
B_functor b_element_op,
C_functor c_element_op,
int M,
int N)
{
int StrideC = N;
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<DDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<DDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
auto averageOpInst = UnaryDivElementOp{M};
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
for(int n = 0; n < N; ++n)
{
AccDataType acc =
static_cast<AccDataType>(c_m_n(m, n)) + static_cast<AccDataType>(bias_n(n));
c_element_op(acc, acc);
acc += static_cast<AccDataType>(c1_m_n(m, n));
c_m_n(m, n) = static_cast<CDataType>(acc);
}
// reduce_mean and reduce_square_mean
auto reduceSumOpInst = ReduceSumOp{};
for(int m = 0; m < M; ++m)
{
float mean_acc = reduceSumOpInst.GetReductionZeroVal();
float square_mean_acc = reduceSumOpInst.GetReductionZeroVal();
for(int n = 0; n < N; ++n)
{
AccDataType c_val = ck::type_convert<AccDataType>(c_m_n(m, n));
AccDataType square_c_val = 0;
UnarySquareElementOp{}(square_c_val, c_val);
reduceSumOpInst(mean_acc, c_val);
reduceSumOpInst(square_mean_acc, square_c_val);
}
averageOpInst(mean_acc, mean_acc);
averageOpInst(square_mean_acc, square_mean_acc);
mean_m(m) = ck::type_convert<DDataType>(mean_acc);
meanSquare_m(m) = ck::type_convert<DDataType>(square_mean_acc);
}
// LayerNorm
auto layerNormInst = NormalizeFunctor{};
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
float out_f32 = 0;
layerNormInst(out_f32, c_m_n(m, n), mean_m(m), meanSquare_m(m), gamma_n(n), beta_n(n));
out_m_n(m, n) = static_cast<DDataType>(out_f32);
}
}
}
template <typename ADataType,
typename BDataType,
typename CDataType,
typename C0DataType,
typename C1DataType,
typename DDataType,
typename GammaDataType,
typename BetaDataType,
typename NormalizeDataType>
void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, int N, int K)
{
std::size_t gemm_flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N;
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(C0DataType) * M * N +
sizeof(C1DataType) * M * N + sizeof(DDataType) * M +
sizeof(DDataType) * M;
std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(DDataType) * M +
sizeof(DDataType) * M + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
float normalize_gb_per_sec = normalize_num_byte / 1.E6 / normalize_time;
std::cout << "gemm + reduce_mean + reduce_square_mean Perf: " << gemm_reduce_time << " ms, "
<< tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl;
std::cout << "5-ary elementwise Perf: " << normalize_time << " ms, " << normalize_gb_per_sec
<< " GB/s, " << std::endl;
}
int main()
{
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024;
ck::index_t StrideC = 1024;
ck::index_t StrideC1 = 1024;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<C0DataType> bias_n(f_host_tensor_descriptor1d(N, 1));
Tensor<C1DataType> c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<DDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
Tensor<DDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
Tensor<LayerNormOutDataType> layerNorm_m_n(
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
bias_n.GenerateTensorValue(GeneratorTensor_3<C0DataType>{-1, 1});
c1_m_n.GenerateTensorValue(GeneratorTensor_3<C1DataType>{-5, 5});
gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
DeviceMem bias_device_buf(sizeof(C0DataType) * bias_n.mDesc.GetElementSpace());
DeviceMem c1_device_buf(sizeof(C1DataType) * c1_m_n.mDesc.GetElementSpace());
DeviceMem reduceMean_device_buf(sizeof(DDataType) * reduceMean_m.mDesc.GetElementSpace());
DeviceMem reduceMeanSquare_device_buf(sizeof(DDataType) *
reduceMeanSquare_m.mDesc.GetElementSpace());
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace());
DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) *
layerNorm_m_n.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
bias_device_buf.ToDevice(bias_n.mData.data());
c1_device_buf.ToDevice(c1_m_n.mData.data());
gamma_device_buf.ToDevice(gamma_n.mData.data());
beta_device_buf.ToDevice(beta_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto dxs_global =
ck::make_tuple(static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()));
auto dxs_in_element_op = DxsInElementOp{};
auto dxs_out_element_op = DxsOutElementOp{M, M};
// Prepare GEMM, reduce_mean, reduce_mean_square
auto gemmReduce = DeviceGemmBiasAddReduceInstance{};
auto gemmReduce_invoker = gemmReduce.MakeInvoker();
auto gemmReduce_argument =
gemmReduce.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<C0DataType*>(bias_device_buf.GetDeviceBuffer()),
static_cast<C1DataType*>(c1_device_buf.GetDeviceBuffer()),
dxs_global,
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideC1,
a_element_op,
b_element_op,
c_element_op,
dxs_in_element_op,
dxs_out_element_op);
if(!gemmReduce.IsSupportedArgument(gemmReduce_argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
reduceMean_device_buf.SetZero();
reduceMeanSquare_device_buf.SetZero();
// Prepare LayerNorm
auto normalize = DeviceNormalizeInstance{};
auto normalize_invoker = normalize.MakeInvoker();
auto normalize_argument = normalize.MakeArgument(
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()),
static_cast<GammaDataType*>(gamma_device_buf.GetDeviceBuffer()),
static_cast<BetaDataType*>(beta_device_buf.GetDeviceBuffer()),
static_cast<LayerNormOutDataType*>(layerNorm_device_buf.GetDeviceBuffer()),
{M, N},
{StrideC, 1},
{1, 0},
{1, 0},
{0, 1},
{0, 1},
{StrideC, 1},
NormalizeFunctor{});
if(!normalize.IsSupportedArgument(normalize_argument))
{
throw std::runtime_error("The runtime parameters seems not supported by the "
"Device5AryElementwise instance, exiting!");
}
// run kernel
gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false});
normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, false});
bool pass = true;
{
// verification
Tensor<LayerNormOutDataType> host_layerNorm_m_n(
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
host_gemm_layernorm<CDataType, DDataType, ReduceAccDataType>(host_layerNorm_m_n,
a_m_k,
b_k_n,
bias_n,
c1_m_n,
gamma_n,
beta_n,
a_element_op,
b_element_op,
c_element_op,
M,
N);
layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
pass &= ck::utils::check_err(layerNorm_m_n.mData,
host_layerNorm_m_n.mData,
"Error: Incorrect results layerNorm_m_n",
1e-2,
1e-2);
}
{
// evaluate kernel perf
bool time_kernel = true;
float gemm_reduce_mean_reduce_square_mean_ave_time =
gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel});
float normalize_ave_time =
normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel});
if(time_kernel)
DumpGemmLayerNormPerf<ADataType,
BDataType,
CDataType,
C0DataType,
C1DataType,
DDataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType>(
gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K);
}
return pass ? 0 : 1;
}
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_reduce_xdl_cshuffle_v1.hpp" #include "gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp"
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
namespace ck { namespace ck {
...@@ -23,6 +23,8 @@ template <typename ALayout, ...@@ -23,6 +23,8 @@ template <typename ALayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename CDataType,
typename C0DataType,
typename C1DataType,
typename GemmAccDataType, typename GemmAccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename ReduceAccDataType, typename ReduceAccDataType,
...@@ -68,14 +70,15 @@ template <typename ALayout, ...@@ -68,14 +70,15 @@ template <typename ALayout,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, struct DeviceGemmBiasAddReduce_Xdl_CShuffle
: public DeviceGemmBiasAddReduce<DPtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation> DxsAccElementwiseOperation>
{ {
using DeviceOp = DeviceGemmReduce_Xdl_CShuffle; using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -374,14 +377,18 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -374,14 +377,18 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 0));
using C1GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1)); using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< using GridwiseGemm = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
CDataType, CDataType,
C0DataType,
C1DataType,
ReduceAccDataType, ReduceAccDataType,
DPtrsGlobal, DPtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
...@@ -395,6 +402,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -395,6 +402,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
CGridDesc_M_N, CGridDesc_M_N,
C0GridDesc_M_N,
C1GridDesc_M_N,
DGridDesc_M, DGridDesc_M,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
...@@ -438,6 +447,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -438,6 +447,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
Argument(const ADataType* p_a_grid, Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid, const BDataType* p_b_grid,
CDataType* p_c_grid, CDataType* p_c_grid,
const C0DataType* p_c0_grid,
const C1DataType* p_c1_grid,
DPtrsGlobal p_ds_grid, DPtrsGlobal p_ds_grid,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
...@@ -445,6 +456,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -445,6 +456,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t StrideC1,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
...@@ -453,12 +465,18 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -453,12 +465,18 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
p_c0_grid_{p_c0_grid},
p_c1_grid_{p_c1_grid},
p_ds_grid_{p_ds_grid}, p_ds_grid_{p_ds_grid},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
c0_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, 0)},
c1_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC1)},
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)}, d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{},
c0_grid_desc_mblock_mperblock_nblock_nperblock_{},
c1_grid_desc_mblock_mperblock_nblock_nperblock_{},
d_grid_desc_mblock_mperblock_{}, d_grid_desc_mblock_mperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
...@@ -476,6 +494,14 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -476,6 +494,14 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_); c_grid_desc_m_n_);
c0_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c0_grid_desc_m_n_);
c1_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c1_grid_desc_m_n_);
d_grid_desc_mblock_mperblock_ = d_grid_desc_mblock_mperblock_ =
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_); GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
} }
...@@ -485,13 +511,21 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -485,13 +511,21 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
const C0DataType* p_c0_grid_;
const C1DataType* p_c1_grid_;
DPtrsGlobal p_ds_grid_; DPtrsGlobal p_ds_grid_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
C0GridDesc_M_N c0_grid_desc_m_n_;
C1GridDesc_M_N c1_grid_desc_m_n_;
DGridDesc_M d_grid_desc_m_; DGridDesc_M d_grid_desc_m_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c0_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_; typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -508,26 +542,6 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -508,26 +542,6 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}"
<< std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
...@@ -545,10 +559,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -545,10 +559,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
float elapsed_time = 0.0f; float elapsed_time = 0.0f;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1< const auto kernel = kernel_gemm_bias_add_reduce_xdl_cshuffle_v1<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
C0DataType,
C1DataType,
DPtrsGlobal, DPtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
...@@ -558,6 +574,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -558,6 +574,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
true>; true>;
...@@ -571,6 +589,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -571,6 +589,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.p_ds_grid_, arg.p_ds_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
...@@ -580,15 +600,19 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -580,15 +600,19 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_, arg.d_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
else else
{ {
const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1< const auto kernel = kernel_gemm_bias_add_reduce_xdl_cshuffle_v1<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
C0DataType,
C1DataType,
DPtrsGlobal, DPtrsGlobal,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
...@@ -598,6 +622,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -598,6 +622,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
false>; false>;
...@@ -611,6 +637,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -611,6 +637,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.p_c0_grid_,
arg.p_c1_grid_,
arg.p_ds_grid_, arg.p_ds_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
...@@ -620,6 +648,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -620,6 +648,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.d_grid_desc_mblock_mperblock_, arg.d_grid_desc_mblock_mperblock_,
arg.block_2_ctile_map_); arg.block_2_ctile_map_);
} }
...@@ -658,6 +688,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -658,6 +688,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
static auto MakeArgument(const ADataType* p_a, static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b, const BDataType* p_b,
CDataType* p_c, CDataType* p_c,
const C0DataType* p_c0,
const C1DataType* p_c1,
DPtrsGlobal p_dxs, DPtrsGlobal p_dxs,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
...@@ -665,6 +697,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -665,6 +697,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t StrideC1,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
...@@ -674,6 +707,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -674,6 +707,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
return Argument{p_a, return Argument{p_a,
p_b, p_b,
p_c, p_c,
p_c0,
p_c1,
p_dxs, p_dxs,
MRaw, MRaw,
NRaw, NRaw,
...@@ -681,6 +716,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -681,6 +716,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
StrideC1,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
...@@ -694,6 +730,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -694,6 +730,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
void* p_c, void* p_c,
const void* p_c0,
const void* p_c1,
DPtrsGlobal p_dxs, DPtrsGlobal p_dxs,
index_t MRaw, index_t MRaw,
index_t NRaw, index_t NRaw,
...@@ -701,6 +739,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -701,6 +739,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
index_t StrideC1,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
...@@ -711,6 +750,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -711,6 +750,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
static_cast<const C0DataType*>(p_c0),
static_cast<const C1DataType*>(p_c1),
p_dxs, p_dxs,
MRaw, MRaw,
NRaw, NRaw,
...@@ -718,6 +759,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal, ...@@ -718,6 +759,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<DPtrsGlobal,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
StrideC1,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
......
...@@ -48,6 +48,52 @@ using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal, ...@@ -48,6 +48,52 @@ using DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<DPtrsGlobal,
DxsInElementwiseOperation, DxsInElementwiseOperation,
DxsAccElementwiseOperation>>; DxsAccElementwiseOperation>>;
template <typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation>
struct DeviceGemmBiasAddReduce : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0,
const void* p_c1,
DPtrsGlobal p_dxs,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
DxsInElementwiseOperation dxs_in_element_op,
DxsAccElementwiseOperation dxs_out_element_op,
ck::index_t BatchCount = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename DPtrsGlobal,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DxsInElementwiseOperation,
typename DxsAccElementwiseOperation>
using DeviceGemmBiasAddReducePtr =
std::unique_ptr<DeviceGemmBiasAddReduce<DPtrsGlobal,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DxsInElementwiseOperation,
DxsAccElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -143,6 +143,21 @@ struct AddHardswishAdd ...@@ -143,6 +143,21 @@ struct AddHardswishAdd
} }
}; };
struct Relu
{
__host__ __device__ void operator()(float& y, const float& x) const { y = x > 0 ? x : 0; }
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x > 0 ? x : 0; }
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { y = x > 0 ? x : 0; }
__host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x > 0 ? x : 0; }
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x > 0 ? x : 0; }
__host__ __device__ void operator()(double& y, const double& x) const { y = x > 0 ? x : 0; }
};
struct Normalize struct Normalize
{ {
Normalize(float epsilon = 1e-4) : epsilon_(epsilon) {} Normalize(float epsilon = 1e-4) : epsilon_(epsilon) {}
......
...@@ -16,6 +16,8 @@ namespace ck { ...@@ -16,6 +16,8 @@ namespace ck {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
typename FloatC0,
typename FloatC1,
typename DPtrsGlobal, typename DPtrsGlobal,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
...@@ -25,6 +27,8 @@ template <typename GridwiseGemm, ...@@ -25,6 +27,8 @@ template <typename GridwiseGemm,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename DGridDescriptor_MBlock_MPerBlock, typename DGridDescriptor_MBlock_MPerBlock,
typename Block2CTileMap, typename Block2CTileMap,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
...@@ -32,10 +36,12 @@ __global__ void ...@@ -32,10 +36,12 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_reduce_xdl_cshuffle_v1( kernel_gemm_bias_add_reduce_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid,
const FloatC1* __restrict__ p_c1_grid,
DPtrsGlobal p_ds_grid, DPtrsGlobal p_ds_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
...@@ -46,6 +52,10 @@ __global__ void ...@@ -46,6 +52,10 @@ __global__ void
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c0_grid_desc_mblock_mperblock_nblock_nperblock,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c1_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock, const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
...@@ -55,6 +65,8 @@ __global__ void ...@@ -55,6 +65,8 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_c0_grid,
p_c1_grid,
p_ds_grid, p_ds_grid,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -65,12 +77,16 @@ __global__ void ...@@ -65,12 +77,16 @@ __global__ void
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c1_grid_desc_mblock_mperblock_nblock_nperblock,
d_grid_desc_mblock_mperblock, d_grid_desc_mblock_mperblock,
block_2_ctile_map); block_2_ctile_map);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_c_grid; ignore = p_c_grid;
ignore = p_c0_grid;
ignore = p_c1_grid;
ignore = p_ds_grid; ignore = p_ds_grid;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
...@@ -80,6 +96,8 @@ __global__ void ...@@ -80,6 +96,8 @@ __global__ void
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = c0_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = d_grid_desc_mblock_mperblock; ignore = d_grid_desc_mblock_mperblock;
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
...@@ -89,6 +107,8 @@ template <typename FloatAB, ...@@ -89,6 +107,8 @@ template <typename FloatAB,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
typename FloatC, typename FloatC,
typename FloatC0,
typename FloatC1,
typename FloatReduceAcc, typename FloatReduceAcc,
typename DPtrsGlobal, typename DPtrsGlobal,
typename AElementwiseOperation, typename AElementwiseOperation,
...@@ -102,6 +122,8 @@ template <typename FloatAB, ...@@ -102,6 +122,8 @@ template <typename FloatAB,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
typename C0GridDesc_M_N,
typename C1GridDesc_M_N,
typename DGridDesc_M, typename DGridDesc_M,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
...@@ -138,7 +160,7 @@ template <typename FloatAB, ...@@ -138,7 +160,7 @@ template <typename FloatAB,
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
LoopScheduler LoopSched> LoopScheduler LoopSched>
struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -268,8 +290,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -268,8 +290,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
} }
template <typename CGridDesc_M_N_>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N_& c_grid_desc_m_n)
{ {
const auto M = c_grid_desc_m_n.GetLength(I0); const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1); const auto N = c_grid_desc_m_n.GetLength(I1);
...@@ -313,6 +336,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -313,6 +336,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>; MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C0GridDesc_M_N{}))>;
using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))>;
using DGridDescriptor_MBlock_MPerBlock = using DGridDescriptor_MBlock_MPerBlock =
remove_cvref_t<decltype(MakeDGridDescriptor_MBlock_MPerBlock(DGridDesc_M{}))>; remove_cvref_t<decltype(MakeDGridDescriptor_MBlock_MPerBlock(DGridDesc_M{}))>;
...@@ -323,6 +352,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -323,6 +352,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const FloatC0* __restrict__ p_c0_grid,
const FloatC1* __restrict__ p_c1_grid,
DPtrsGlobal p_ds_grid, DPtrsGlobal p_ds_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
...@@ -334,6 +365,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -334,6 +365,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c0_grid_desc_mblock_mperblock_nblock_nperblock,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c1_grid_desc_mblock_mperblock_nblock_nperblock,
const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock, const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map)
{ {
...@@ -343,6 +378,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -343,6 +378,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c0_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c1_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N] // divide block work by [M, N]
const auto block_work_idx = const auto block_work_idx =
...@@ -610,32 +649,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -610,32 +649,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
n_thread_data_on_block_idx[I2]), n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}}; ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR // space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr = constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>, SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
...@@ -759,14 +772,80 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -759,14 +772,80 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}, },
Number<p_ds_grid.Size()>{}); Number<p_ds_grid.Size()>{});
// c0 and c1
constexpr auto c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));
constexpr auto c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock;
auto c01_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatC0,
FloatReduceAcc,
decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
Sequence<0, 1, 2, 3>,
3,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
1,
true>(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
auto c1_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatC1,
FloatReduceAcc,
decltype(c1_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>,
Sequence<0, 1, 2, 3>,
3,
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
1,
true>(
c1_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]));
constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<mreduce_per_thread>{}, I1, Number<nreduce_per_thread>{}));
auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
FloatReduceAcc,
FloatC,
decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
tensor_operation::element_wise::PassThrough,
Sequence<I1, mreduce_per_thread, I1, nreduce_per_thread>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
3, // DstVectorDim
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
InMemoryDataOperationEnum::Set,
1,
true>{
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(I0,
m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0],
I0,
n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]),
tensor_operation::element_wise::PassThrough{}};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) { static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS // each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
...@@ -774,17 +853,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -774,17 +853,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf); c_shuffle_block_buf);
// make sure it's safe to read from LDS // make sure it's safe to write to LDS
block_sync_lds(); block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
// TODO - extract following into reduction_blockwise
{ {
c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
c_shuffle_block_buf, c_shuffle_block_buf,
...@@ -792,6 +862,37 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -792,6 +862,37 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_tuple(I0, I0), make_tuple(I0, I0),
c_reduce_thread_buf); c_reduce_thread_buf);
c0_thread_copy_global_to_vgpr.Run(
c0_grid_desc_mblock_mperblock_nblock_nperblock,
c0_grid_buf,
c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c01_thread_buf);
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) {
FloatReduceAcc out;
c_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i));
c_reduce_thread_buf(i) = out;
});
c1_thread_copy_global_to_vgpr.Run(
c1_grid_desc_mblock_mperblock_nblock_nperblock,
c1_grid_buf,
c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c01_thread_buf);
static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}(
[&](auto i) { c_reduce_thread_buf(i) += c01_thread_buf(i); });
c_reduce_thread_copy_vgpr_to_global.Run(
c_reduce_thread_desc_mblock_mperblock_nblock_nperblock,
make_tuple(I0, I0, I0, I0),
c_reduce_thread_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) { static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
auto& p_d_grid = p_ds_grid[In]; auto& p_d_grid = p_ds_grid[In];
...@@ -858,13 +959,19 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -858,13 +959,19 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C // move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
});
// Reduction // move on C0
c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
// move on C1
c1_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
c1_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
} }
});
} // Reduction
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment