Commit e71afee2 authored by Jing Zhang
Browse files

add multiD support into batched_gemm_c_permute

parent 85978e02
...@@ -178,14 +178,17 @@ int main(int argc, char* argv[]) ...@@ -178,14 +178,17 @@ int main(int argc, char* argv[])
// do GEMM // do GEMM
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()), static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
{},
static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()),
M, M,
N, N,
K, K,
stride_A, stride_A,
stride_B, stride_B,
{},
batch_stride_A, batch_stride_A,
batch_stride_B, batch_stride_B,
{},
batched_gemm_c_permute_desc, batched_gemm_c_permute_desc,
batch_count, batch_count,
a_element_op, a_element_op,
......
...@@ -16,26 +16,32 @@ struct BatchedGemmCPermuteDesc ...@@ -16,26 +16,32 @@ struct BatchedGemmCPermuteDesc
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename DELayout, typename DLayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename DsDataType,
typename EDataType, typename EDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation> typename CDEElementwiseOperation>
struct DeviceBatchedGemmCPermute : public BaseOperator struct DeviceBatchedGemmCPermute : public BaseOperator
{ {
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c, void* p_c,
index_t M, index_t M,
index_t N, index_t N,
index_t K, index_t K,
index_t stride_A, index_t stride_A,
index_t stride_B, index_t stride_B,
std::array<index_t, NumDTensor> stride_Ds,
index_t batch_stride_A, index_t batch_stride_A,
index_t batch_stride_B, index_t batch_stride_B,
std::array<index_t, NumDTensor> batch_stride_Ds,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc, BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
index_t BatchCount, index_t BatchCount,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
...@@ -45,26 +51,6 @@ struct DeviceBatchedGemmCPermute : public BaseOperator ...@@ -45,26 +51,6 @@ struct DeviceBatchedGemmCPermute : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
// Alias template: a std::unique_ptr owning a DeviceBatchedGemmCPermute that is
// instantiated with the same nine template arguments, forwarded in declaration order.
// NOTE(review): this parameter list has no DsDataType entry, whereas the
// DeviceBatchedGemmCPermute declaration visible earlier in this file lists one in its
// right-hand column -- confirm which revision of the struct this alias belongs to.
template <typename ALayout,
typename BLayout,
typename DELayout,
typename ADataType,
typename BDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
using DeviceBatchedGemmCPermutePtr =
std::unique_ptr<DeviceBatchedGemmCPermute<ALayout,
BLayout,
DELayout,
ADataType,
BDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -45,10 +45,12 @@ namespace device { ...@@ -45,10 +45,12 @@ namespace device {
*/ */
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatDsPointer,
typename FloatC, typename FloatC,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation, typename CDEElementwiseOperation,
...@@ -61,12 +63,15 @@ __global__ void ...@@ -61,12 +63,15 @@ __global__ void
#endif #endif
kernel_batched_gemm_c_permute_xdl(const FloatAB* __restrict__ p_a_grid, kernel_batched_gemm_c_permute_xdl(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatDsPointer p_ds_grid,
FloatC* __restrict__ p_e_grid, FloatC* __restrict__ p_e_grid,
const index_t batch_count, const index_t batch_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
...@@ -87,10 +92,19 @@ __global__ void ...@@ -87,10 +92,19 @@ __global__ void
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>( const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
p_a_grid + a_batch_offset,
FloatDsPointer p_ds_grid_grp;
static constexpr index_t NumDTensor =
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
ck::Tuple<>{}, p_ds_grid_grp,
p_e_grid + c_batch_offset, p_e_grid + c_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -98,19 +112,18 @@ __global__ void ...@@ -98,19 +112,18 @@ __global__ void
cde_element_op, cde_element_op,
a_grid_desc_k0_m_k1, a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1, b_grid_desc_k0_n_k1,
ck::StaticallyIndexedArray< ds_grid_desc_mblock_mperblock_nblock_nperblock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, e_grid_desc_mblock_mperblock_nblock_nperblock,
0>{},
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map); block_2_ctile_map);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_e_grid; ignore = p_ds_grid, ignore = p_e_grid;
ignore = batch_count; ignore = batch_count;
ignore = a_grid_desc_k0_m_k1; ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1; ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock,
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
ignore = cde_element_op; ignore = cde_element_op;
...@@ -121,7 +134,7 @@ __global__ void ...@@ -121,7 +134,7 @@ __global__ void
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename DELayout, typename DLayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename GemmAccDataType, typename GemmAccDataType,
...@@ -164,9 +177,10 @@ template <typename ALayout, ...@@ -164,9 +177,10 @@ template <typename ALayout,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout, struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
BLayout, BLayout,
DELayout, DLayout,
ADataType, ADataType,
BDataType, BDataType,
DsDataType,
EDataType, EDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
...@@ -175,6 +189,8 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout, ...@@ -175,6 +189,8 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
using DeviceOp = DeviceBatchedGemmCPermuteXdl; using DeviceOp = DeviceBatchedGemmCPermuteXdl;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
...@@ -385,13 +401,21 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout, ...@@ -385,13 +401,21 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
} }
} }
static auto static auto MakeDGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
{ {
const auto c_grid_desc_mraw_nraw = [&]() { const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, DLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(stride_M, stride_N)); make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE));
}
}(); }();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
...@@ -435,6 +459,56 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout, ...@@ -435,6 +459,56 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
} }
} }
// Builds the 2-D (M, N) descriptor of the E tensor from explicit per-dimension strides
// and, depending on the compile-time GemmSpec, right-pads M and/or N up to the next
// multiple of MPerBlock / NPerBlock.
//
//   MRaw, NRaw         unpadded extents of the two dimensions
//   stride_M, stride_N strides handed to make_naive_tensor_descriptor for dims 0 and 1
//
// Returns the raw descriptor when GemmSpec selects no M/N padding, otherwise the
// transformed (padded) descriptor; the return type therefore depends on GemmSpec.
static auto
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
{
    const auto raw_desc = make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                       make_tuple(stride_M, stride_N));

    // Amount needed to round each extent up to a whole number of block tiles.
    const auto MPad = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock - MRaw;
    const auto NPad = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock - NRaw;

    // Applies one transform per dimension; dimension 0 stays 0 and dimension 1 stays 1.
    const auto remap_dims = [&](const auto& m_transform, const auto& n_transform) {
        return transform_tensor_descriptor(raw_desc,
                                           make_tuple(m_transform, n_transform),
                                           make_tuple(Sequence<0>{}, Sequence<1>{}),
                                           make_tuple(Sequence<0>{}, Sequence<1>{}));
    };

    if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                 GemmSpec == GemmSpecialization::MNKPadding)
    {
        // both M and N are padded
        return remap_dims(make_right_pad_transform(MRaw, MPad),
                          make_right_pad_transform(NRaw, NPad));
    }
    else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                      GemmSpec == GemmSpecialization::MKPadding)
    {
        // only M is padded, N passes through
        return remap_dims(make_right_pad_transform(MRaw, MPad),
                          make_pass_through_transform(NRaw));
    }
    else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                      GemmSpec == GemmSpecialization::NKPadding)
    {
        // only N is padded, M passes through
        return remap_dims(make_pass_through_transform(MRaw),
                          make_right_pad_transform(NRaw, NPad));
    }
    else
    {
        // neither M nor N is padded
        return raw_desc;
    }
}
static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0, static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0,
index_t G1, index_t G1,
index_t MRaw, index_t MRaw,
...@@ -509,23 +583,34 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout, ...@@ -509,23 +583,34 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
struct ComputePtrOffsetOfStridedBatch struct ComputePtrOffsetOfStridedBatch
{ {
ComputePtrOffsetOfStridedBatch(index_t Batchstride_A, ComputePtrOffsetOfStridedBatch(index_t batch_stride_A,
index_t Batchstride_B, index_t batch_stride_B,
std::array<ck::index_t, NumDTensor> BatchStride_Ds,
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n) EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n)
: Batchstride_A_(Batchstride_A), : batch_stride_A_(batch_stride_A),
Batchstride_B_(Batchstride_B), batch_stride_B_(batch_stride_B),
batch_stride_Ds(BatchStride_Ds),
e_grid_desc_g0_g1_m_n_(e_grid_desc_g0_g1_m_n) e_grid_desc_g0_g1_m_n_(e_grid_desc_g0_g1_m_n)
{ {
} }
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(Batchstride_A_); return g_idx * static_cast<long_index_t>(batch_stride_A_);
} }
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(Batchstride_B_); return g_idx * static_cast<long_index_t>(batch_stride_B_);
}
// Returns one offset per D tensor for batch `g_idx`: element i is
// g_idx * batch_stride_Ds[i], with the stride cast to long_index_t before the multiply.
// The result is a std::array<long_index_t, NumDTensor>.
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
    // Value-initialized on purpose: before C++20 a constexpr function may not define a
    // local without an initializer, and every element is well-defined even if
    // NumDTensor is 0 and the loop below runs no iterations.
    std::array<long_index_t, NumDTensor> ds_offset{};

    static_for<0, NumDTensor, 1>{}([&](auto i) {
        ds_offset[i] = g_idx * static_cast<long_index_t>(batch_stride_Ds[i]);
    });

    return ds_offset;
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
...@@ -537,8 +622,9 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout, ...@@ -537,8 +622,9 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
} }
private: private:
index_t Batchstride_A_; index_t batch_stride_A_;
index_t Batchstride_B_; index_t batch_stride_B_;
std::array<ck::index_t, NumDTensor> batch_stride_Ds;
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_; EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
}; };
...@@ -588,32 +674,36 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout, ...@@ -588,32 +674,36 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype( using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})); GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
using Block2CTileMap = typename GridwiseGemm::DefaultBlock2ETileMap; using Block2CTileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const ADataType* p_a_grid, Argument(const void* p_a_grid,
const BDataType* p_b_grid, const void* p_b_grid,
EDataType* p_e_grid, std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
index_t M, index_t M,
index_t N, index_t N,
index_t K, index_t K,
index_t stride_A, index_t stride_A,
index_t stride_B, index_t stride_B,
std::array<index_t, NumDTensor> stride_Ds,
index_t batch_stride_A, index_t batch_stride_A,
index_t batch_stride_B, index_t batch_stride_B,
std::array<index_t, NumDTensor> batch_stride_Ds,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc, BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
index_t BatchCount, index_t batch_count,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) CDEElementwiseOperation cde_element_op)
: p_a_grid_{p_a_grid}, : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{p_b_grid}, p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_e_grid_{p_e_grid}, p_ds_grid_{}, // FIXME
BatchCount_(BatchCount), p_e_grid_{static_cast<EDataType*>(p_e_grid)},
batch_count_(batch_count),
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1_{
DeviceBatchedGemmCPermuteXdl::MakeAGridDescriptor_AK0_M_AK1(M, K, stride_A)}, DeviceBatchedGemmCPermuteXdl::MakeAGridDescriptor_AK0_M_AK1(M, K, stride_A)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
...@@ -632,8 +722,10 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout, ...@@ -632,8 +722,10 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
batched_gemm_c_permute_desc.stride_G1_, batched_gemm_c_permute_desc.stride_G1_,
batched_gemm_c_permute_desc.stride_M_, batched_gemm_c_permute_desc.stride_M_,
batched_gemm_c_permute_desc.stride_N_)}, batched_gemm_c_permute_desc.stride_N_)},
c_grid_desc_mblock_mperblock_nblock_nperblock{}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
compute_ptr_offset_of_batch_{batch_stride_A, batch_stride_B, e_grid_desc_g0_g1_m_n_}, e_grid_desc_mblock_mperblock_nblock_nperblock{},
compute_ptr_offset_of_batch_{
batch_stride_A, batch_stride_B, batch_stride_Ds, e_grid_desc_g0_g1_m_n_},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
...@@ -645,22 +737,44 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout, ...@@ -645,22 +737,44 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
e_grid_desc_m_n_, e_grid_desc_m_n_,
block_2_ctile_map_)) block_2_ctile_map_))
{ {
c_grid_desc_mblock_mperblock_nblock_nperblock = e_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_); e_grid_desc_m_n_);
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
const auto d_grid_desc_m_n =
DeviceOp::MakeDGridDescriptor_M_N(M, N, stride_Ds[i]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
d_grid_desc_m_n);
});
} }
} }
// private: // private:
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_; EDataType* p_e_grid_;
index_t BatchCount_;
index_t batch_count_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
EGridDesc_M_N e_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_;
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_; EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock;
StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>
ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
Block2CTileMap block_2_ctile_map_; Block2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -701,7 +815,7 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout, ...@@ -701,7 +815,7 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
} }
const index_t grid_size = const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.BatchCount_; arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.batch_count_;
const auto K = const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
...@@ -712,9 +826,13 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout, ...@@ -712,9 +826,13 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
const auto kernel = kernel_batched_gemm_c_permute_xdl< const auto kernel = kernel_batched_gemm_c_permute_xdl<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType, EDataType,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
ck::StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
...@@ -730,11 +848,13 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout, ...@@ -730,11 +848,13 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_, arg.p_e_grid_,
arg.BatchCount_, arg.batch_count_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.cde_element_op_, arg.cde_element_op_,
...@@ -778,32 +898,38 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout, ...@@ -778,32 +898,38 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
static auto MakeArgument(const ADataType* p_a, static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b, const BDataType* p_b,
EDataType* p_c, std::array<const void*, NumDTensor> p_ds,
EDataType* p_e,
index_t M, index_t M,
index_t N, index_t N,
index_t K, index_t K,
index_t stride_A, index_t stride_A,
index_t stride_B, index_t stride_B,
std::array<index_t, NumDTensor> stride_Ds,
index_t batch_stride_A, index_t batch_stride_A,
index_t batch_stride_B, index_t batch_stride_B,
std::array<index_t, NumDTensor> batch_stride_Ds,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc, BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
index_t BatchCount, index_t batch_count,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) CDEElementwiseOperation cde_element_op)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
p_c, p_ds,
p_e,
M, M,
N, N,
K, K,
stride_A, stride_A,
stride_B, stride_B,
stride_Ds,
batch_stride_A, batch_stride_A,
batch_stride_B, batch_stride_B,
batch_stride_Ds,
batched_gemm_c_permute_desc, batched_gemm_c_permute_desc,
BatchCount, batch_count,
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op}; cde_element_op};
...@@ -815,32 +941,38 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout, ...@@ -815,32 +941,38 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
void* p_c, std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M, index_t M,
index_t N, index_t N,
index_t K, index_t K,
index_t stride_A, index_t stride_A,
index_t stride_B, index_t stride_B,
std::array<index_t, NumDTensor> stride_Ds,
index_t batch_stride_A, index_t batch_stride_A,
index_t batch_stride_B, index_t batch_stride_B,
std::array<index_t, NumDTensor> batch_stride_Ds,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc, BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
index_t BatchCount, index_t batch_count,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override CDEElementwiseOperation cde_element_op) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(p_a,
static_cast<const BDataType*>(p_b), p_b,
static_cast<EDataType*>(p_c), p_ds,
p_e,
M, M,
N, N,
K, K,
stride_A, stride_A,
stride_B, stride_B,
stride_Ds,
batch_stride_A, batch_stride_A,
batch_stride_B, batch_stride_B,
batch_stride_Ds,
batched_gemm_c_permute_desc, batched_gemm_c_permute_desc,
BatchCount, batch_count,
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op); cde_element_op);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment