Commit e71afee2 authored by Jing Zhang

add multi-D tensor support to batched_gemm_c_permute

parent 85978e02
@@ -178,14 +178,17 @@ int main(int argc, char* argv[])
// do GEMM
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
{},
static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
stride_A,
stride_B,
{},
batch_stride_A,
batch_stride_B,
{},
batched_gemm_c_permute_desc,
batch_count,
a_element_op,
......
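Note on the call above: this example uses no D tensors, so NumDTensor is zero and the three new array parameters (the D pointers, stride_Ds, and batch_stride_Ds) are passed as empty braces. A minimal stand-alone sketch of that calling pattern — make_argument here is an illustrative stand-in, not CK's actual API:

#include <array>
#include <cstdio>

// Hypothetical stand-in for MakeArgument: the D-tensor parameters are
// fixed-size arrays whose length equals the number of D tensors.
template <int NumDTensor>
void make_argument(std::array<const void*, NumDTensor> p_ds,
                   std::array<int, NumDTensor> stride_Ds)
{
    (void)p_ds;
    (void)stride_Ds;
    std::printf("NumDTensor = %d\n", NumDTensor);
}

int main()
{
    // No D tensors: both arrays are empty, so `{}` works at the call site,
    // exactly as in the example above.
    make_argument<0>({}, {});

    // One D tensor (e.g. a bias consumed by the CDE elementwise op).
    static float bias[8] = {};
    make_argument<1>({bias}, {0});
}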
@@ -16,26 +16,32 @@ struct BatchedGemmCPermuteDesc
template <typename ALayout,
typename BLayout,
typename DELayout,
typename DLayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceBatchedGemmCPermute : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
std::array<index_t, NumDTensor> stride_Ds,
index_t batch_stride_A,
index_t batch_stride_B,
std::array<index_t, NumDTensor> batch_stride_Ds,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
index_t BatchCount,
AElementwiseOperation a_element_op,
@@ -45,26 +51,6 @@ struct DeviceBatchedGemmCPermute : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
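NumDTensor in the interface above comes from DsDataType::Size(): DsDataType is a ck::Tuple of the D tensors' element types, and its static Size() fixes the length of every per-D array parameter at compile time. A self-contained sketch of the mechanism, with a std::tuple-style stand-in for ck::Tuple:

#include <cstddef>

// Stand-in for ck::Tuple<...>: a type list exposing a static Size().
template <typename... Ts>
struct Ds
{
    static constexpr std::size_t Size() { return sizeof...(Ts); }
};

// Mirrors how DeviceBatchedGemmCPermute derives its array lengths.
template <typename DsDataType>
struct DeviceOpSketch
{
    static constexpr std::size_t NumDTensor = DsDataType::Size();
};

static_assert(DeviceOpSketch<Ds<>>::NumDTensor == 0, "no D tensors");
static_assert(DeviceOpSketch<Ds<float>>::NumDTensor == 1, "one D, e.g. a bias");
static_assert(DeviceOpSketch<Ds<float, float>>::NumDTensor == 2, "two Ds");

int main() {}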
template <typename ALayout,
typename BLayout,
typename DELayout,
typename ADataType,
typename BDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
using DeviceBatchedGemmCPermutePtr =
std::unique_ptr<DeviceBatchedGemmCPermute<ALayout,
BLayout,
DELayout,
ADataType,
BDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -45,10 +45,12 @@ namespace device {
*/
template <typename GridwiseGemm,
typename FloatAB,
typename FloatDsPointer,
typename FloatC,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
@@ -61,12 +63,15 @@ __global__ void
#endif
kernel_batched_gemm_c_permute_xdl(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatDsPointer p_ds_grid,
FloatC* __restrict__ p_e_grid,
const index_t batch_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
@@ -87,10 +92,19 @@ __global__ void
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_a_grid + a_batch_offset,
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
FloatDsPointer p_ds_grid_grp;
static constexpr index_t NumDTensor =
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
ck::Tuple<>{},
p_ds_grid_grp,
p_e_grid + c_batch_offset,
p_shared,
a_element_op,
@@ -98,19 +112,18 @@ __global__ void
cde_element_op,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
ck::StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
0>{},
c_grid_desc_mblock_mperblock_nblock_nperblock,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_e_grid;
ignore = p_ds_grid, ignore = p_e_grid;
ignore = batch_count;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock,
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
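The kernel change above shifts each D pointer by its own per-batch offset before handing the tuple to GridwiseGemm::Run. The same indexing expressed on the host, with ck::static_for replaced by a plain loop; the element type, sizes, and strides are illustrative:

#include <array>
#include <cstdint>
#include <cstdio>

using long_index_t = std::int64_t; // mirrors ck::long_index_t

int main()
{
    constexpr int NumDTensor = 2;

    static float d0[8192]; // per-batch D tensor, 1024 elements per batch
    static float d1[64];   // D tensor shared by every batch

    std::array<const float*, NumDTensor> p_ds_grid       = {d0, d1};
    std::array<long_index_t, NumDTensor> batch_stride_Ds = {1024, 0};

    const int g_idx = 3; // batch index of this workgroup

    // Equivalent of GetDsPtrOffset(g_idx) plus the static_for in the kernel:
    std::array<const float*, NumDTensor> p_ds_grid_grp;
    for(int i = 0; i < NumDTensor; ++i)
        p_ds_grid_grp[i] = p_ds_grid[i] + g_idx * batch_stride_Ds[i];

    // A batch stride of 0 (d1) makes every batch read the same tensor,
    // which is how a bias shared across batches would be broadcast.
    std::printf("d0 advanced by %td elements, d1 by %td\n",
                p_ds_grid_grp[0] - d0, p_ds_grid_grp[1] - d1);
}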
@@ -121,7 +134,7 @@ __global__ void
template <typename ALayout,
typename BLayout,
typename DELayout,
typename DLayout,
typename ADataType,
typename BDataType,
typename GemmAccDataType,
@@ -164,9 +177,10 @@ template <typename ALayout,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
BLayout,
DELayout,
DLayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
@@ -175,6 +189,8 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
using DeviceOp = DeviceBatchedGemmCPermuteXdl;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
@@ -385,13 +401,21 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
}
}
static auto
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
static auto MakeDGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, DLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(stride_M, stride_N));
make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE));
}
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
@@ -435,6 +459,56 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
}
}
static auto
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
{
const auto e_grid_desc_mraw_nraw = [&]() {
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(stride_M, stride_N));
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(e_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
e_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
e_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return e_grid_desc_mraw_nraw;
}
}
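Both MakeDGridDescriptor_M_N and MakeEGridDescriptor_M_N round M and N up to multiples of MPerBlock and NPerBlock and right-pad the remainder so every workgroup operates on a full tile. The arithmetic checked stand-alone, with illustrative tile sizes and dimensions:

// Same rounding as CK's math::integer_divide_ceil.
constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

constexpr int MPerBlock = 256, NPerBlock = 128; // illustrative tile sizes
constexpr int MRaw = 1000, NRaw = 500;          // raw GEMM dimensions

constexpr int M = integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; // 1024
constexpr int N = integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; // 512

static_assert(M - MRaw == 24, "MPad: right padding applied to M");
static_assert(N - NRaw == 12, "NPad: right padding applied to N");

int main() {}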
static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0,
index_t G1,
index_t MRaw,
@@ -509,23 +583,34 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch(index_t Batchstride_A,
index_t Batchstride_B,
ComputePtrOffsetOfStridedBatch(index_t batch_stride_A,
index_t batch_stride_B,
std::array<ck::index_t, NumDTensor> BatchStride_Ds,
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n)
: Batchstride_A_(Batchstride_A),
Batchstride_B_(Batchstride_B),
: batch_stride_A_(batch_stride_A),
batch_stride_B_(batch_stride_B),
batch_stride_Ds(BatchStride_Ds),
e_grid_desc_g0_g1_m_n_(e_grid_desc_g0_g1_m_n)
{
}
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(Batchstride_A_);
return g_idx * static_cast<long_index_t>(batch_stride_A_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(Batchstride_B_);
return g_idx * static_cast<long_index_t>(batch_stride_B_);
}
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
std::array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_offset[i] = g_idx * static_cast<long_index_t>(batch_stride_Ds[i]);
});
return ds_offset;
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
@@ -537,8 +622,9 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
}
private:
index_t Batchstride_A_;
index_t Batchstride_B_;
index_t batch_stride_A_;
index_t batch_stride_B_;
std::array<ck::index_t, NumDTensor> batch_stride_Ds;
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
};
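One detail the renamed methods preserve: the batch stride is cast to long_index_t before the multiply, so the element offset is computed in 64 bits even though g_idx and the stride are 32-bit index_t values. A stand-alone illustration (the type aliases mirror CK's; the values are chosen to show the overflow risk):

#include <cstdint>
#include <cstdio>

using index_t      = std::int32_t;
using long_index_t = std::int64_t;

int main()
{
    index_t g_idx        = 70000; // large batch index
    index_t batch_stride = 40000; // elements per batch

    // 70000 * 40000 = 2.8e9 exceeds INT32_MAX, so a plain 32-bit multiply
    // would overflow (undefined behavior for signed integers). Widening
    // one operand first, as GetAPtrOffset/GetBPtrOffset/GetDsPtrOffset do,
    // keeps the product exact:
    long_index_t offset = g_idx * static_cast<long_index_t>(batch_stride);

    std::printf("offset = %lld elements\n", static_cast<long long>(offset));
}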
@@ -588,32 +674,36 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
using Block2CTileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
EDataType* p_e_grid,
Argument(const void* p_a_grid,
const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
std::array<index_t, NumDTensor> stride_Ds,
index_t batch_stride_A,
index_t batch_stride_B,
std::array<index_t, NumDTensor> batch_stride_Ds,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
index_t BatchCount,
index_t batch_count,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_e_grid_{p_e_grid},
BatchCount_(BatchCount),
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{}, // FIXME
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
batch_count_(batch_count),
a_grid_desc_ak0_m_ak1_{
DeviceBatchedGemmCPermuteXdl::MakeAGridDescriptor_AK0_M_AK1(M, K, stride_A)},
b_grid_desc_bk0_n_bk1_{
@@ -632,8 +722,10 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
batched_gemm_c_permute_desc.stride_G1_,
batched_gemm_c_permute_desc.stride_M_,
batched_gemm_c_permute_desc.stride_N_)},
c_grid_desc_mblock_mperblock_nblock_nperblock{},
compute_ptr_offset_of_batch_{batch_stride_A, batch_stride_B, e_grid_desc_g0_g1_m_n_},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock{},
compute_ptr_offset_of_batch_{
batch_stride_A, batch_stride_B, batch_stride_Ds, e_grid_desc_g0_g1_m_n_},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
@@ -645,22 +737,44 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
e_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock =
e_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
const auto d_grid_desc_m_n =
DeviceOp::MakeDGridDescriptor_M_N(M, N, stride_Ds[i]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
d_grid_desc_m_n);
});
}
}
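The static_for in the constructor above visits the D tensors at compile time so that each incoming void* can be cast to its own element type via tuple_element_t<i, DsDataType>. The same pattern in isolation, with std::index_sequence standing in for ck::static_for:

#include <array>
#include <cstdio>
#include <tuple>
#include <utility>

// Two D tensors with distinct element types, as DsDataType may hold.
using DsDataType = std::tuple<float, double>;
constexpr std::size_t NumDTensor = std::tuple_size_v<DsDataType>;

template <std::size_t... Is>
void cast_ds(std::array<const void*, NumDTensor> p_ds,
             std::index_sequence<Is...>)
{
    // Each pack element recovers the i-th data type and casts its pointer,
    // mirroring p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]).
    auto typed = std::make_tuple(
        static_cast<const std::tuple_element_t<Is, DsDataType>*>(p_ds[Is])...);
    (void)typed;
    std::printf("cast %zu D pointers to their element types\n", sizeof...(Is));
}

int main()
{
    static float  d0[4] = {};
    static double d1[4] = {};
    cast_ds({d0, d1}, std::make_index_sequence<NumDTensor>{});
}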
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
index_t BatchCount_;
index_t batch_count_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
EGridDesc_M_N e_grid_desc_m_n_;
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock;
StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>
ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different type from E's
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
Block2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
@@ -701,7 +815,7 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.BatchCount_;
arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.batch_count_;
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
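The grid-size computation above launches one set of output tiles per batch: the Block2CTileMap's tile count for a single E matrix, multiplied by batch_count. With illustrative numbers:

// E is 1024 x 512 after padding, tiled 256 x 128 per workgroup.
constexpr int tiles_m          = 1024 / 256;        // 4 tile rows
constexpr int tiles_n          = 512 / 128;         // 4 tile columns
constexpr int blocks_per_batch = tiles_m * tiles_n; // 16 workgroups
constexpr int batch_count      = 8;
constexpr int grid_size        = blocks_per_batch * batch_count;

static_assert(grid_size == 128, "16 tiles per batch x 8 batches");

int main() {}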
@@ -712,9 +826,13 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
const auto kernel = kernel_batched_gemm_c_permute_xdl<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
ck::StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation,
BElementwiseOperation,
@@ -730,11 +848,13 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.BatchCount_,
arg.batch_count_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
@@ -778,32 +898,38 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
EDataType* p_c,
std::array<const void*, NumDTensor> p_ds,
EDataType* p_e,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
std::array<index_t, NumDTensor> stride_Ds,
index_t batch_stride_A,
index_t batch_stride_B,
std::array<index_t, NumDTensor> batch_stride_Ds,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
index_t BatchCount,
index_t batch_count,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{p_a,
p_b,
p_c,
p_ds,
p_e,
M,
N,
K,
stride_A,
stride_B,
stride_Ds,
batch_stride_A,
batch_stride_B,
batch_stride_Ds,
batched_gemm_c_permute_desc,
BatchCount,
batch_count,
a_element_op,
b_element_op,
cde_element_op};
@@ -815,32 +941,38 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
std::array<index_t, NumDTensor> stride_Ds,
index_t batch_stride_A,
index_t batch_stride_B,
std::array<index_t, NumDTensor> batch_stride_Ds,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
index_t BatchCount,
index_t batch_count,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<EDataType*>(p_c),
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
M,
N,
K,
stride_A,
stride_B,
stride_Ds,
batch_stride_A,
batch_stride_B,
batch_stride_Ds,
batched_gemm_c_permute_desc,
BatchCount,
batch_count,
a_element_op,
b_element_op,
cde_element_op);
......