Commit e70aa44a authored by Jing Zhang

add batch_strides into bmm_c_permute

parent ffd6943e
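This commit threads explicit per-batch strides through the batched-GEMM-with-C-permute operation: `MakeArgument`/`MakeArgumentPointer` gain `batch_stride_A` and `batch_stride_B` parameters, `ComputePtrOffsetOfStridedBatch` is constructed from those values instead of deriving offsets from the A/B grid descriptors' element space sizes, and the batch count moves ahead of the element-wise operators in the argument lists. The C-side naming is also aligned with the multi-D convention (`CDataType`/`CElementwiseOperation` become `EDataType`/`CDEElementwiseOperation`). A minimal standalone sketch of the addressing model (the helper below is illustrative, not part of the diff):

```cpp
#include <cstdint>

// Element offset of batch g in A, given an explicit batch stride.
// With batch_stride_a == M * K (as in the updated examples) batches are
// densely packed, reproducing the old derived-offset behaviour; a larger
// stride leaves a gap between consecutive batches.
std::int64_t a_batch_offset(std::int64_t g, std::int64_t batch_stride_a)
{
    return g * batch_stride_a;
}

// Example with the new test sizes M = 256, K = 64:
//   batch_stride_A = 256 * 64 = 16384, so batch 3 starts at element 49152.
```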
@@ -26,35 +26,36 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

-using ADataType = ck::half_t;
-using BDataType = ck::half_t;
-using CDataType = ck::half_t;
-using AccDataType = float;
+using ADataType        = F16;
+using BDataType        = F16;
+using AccDataType      = F32;
+using CShuffleDataType = F16;
+using DsDataType       = ck::Tuple<>;
+using EDataType        = F16;

-using ALayout = ck::tensor_layout::gemm::RowMajor;
-using BLayout = ck::tensor_layout::gemm::ColumnMajor;
-using CLayout = ck::tensor_layout::gemm::RowMajor;
+using ALayout = Row;
+using BLayout = Col;
+using ELayout = Row;

-using AElementOp = ck::tensor_operation::element_wise::PassThrough;
-using BElementOp = ck::tensor_operation::element_wise::PassThrough;
-using CElementOp = ck::tensor_operation::element_wise::PassThrough;
+using AElementOp   = PassThrough;
+using BElementOp   = PassThrough;
+using CDEElementOp = PassThrough;

-// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
 // static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
-static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+// static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmCPermuteXdl
-//######| ALayout| BLayout| AData| BData| CData| AccData| A| B| C| GEMM| Num| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-//######| | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
-//######| | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
-//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-// < Row, Col, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, MNPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
-< Row, Col, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
+//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+< ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
 // clang-format on

 using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
-    ReferenceBatchedGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
+    ReferenceBatchedGemm<ADataType, BDataType, EDataType, AElementOp, BElementOp, CDEElementOp>;

 int main(int argc, char* argv[])
 {
@@ -62,15 +63,18 @@ int main(int argc, char* argv[])
     int init_method  = 1;
     bool time_kernel = false;

-    const int M = 88;
-    const int N = 64;
-    const int K = 88;
+    const int M = 256;
+    const int N = 128;
+    const int K = 64;

     const int stride_A = K;
     const int stride_B = K;

-    const int G0 = 1024;
-    const int G1 = 10;
+    const int batch_stride_A = M * K;
+    const int batch_stride_B = K * N;
+
+    const int G0 = 16;
+    const int G1 = 8;

     const int batch_count = G0 * G1;
@@ -102,21 +106,24 @@ int main(int argc, char* argv[])
                                         std::size_t row,
                                         std::size_t col,
                                         std::size_t stride,
+                                        std::size_t batch_stride,
                                         auto layout) {
         if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
         {
             return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
-                                        std::vector<std::size_t>({row * stride, stride, 1}));
+                                        std::vector<std::size_t>({batch_stride, stride, 1}));
         }
         else
         {
             return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
-                                        std::vector<std::size_t>({col * stride, 1, stride}));
+                                        std::vector<std::size_t>({batch_stride, 1, stride}));
         }
     };

-    Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(batch_count, M, K, stride_A, ALayout{}));
-    Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(batch_count, K, N, stride_B, BLayout{}));
+    Tensor<ADataType> a_g_m_k(
+        f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{}));
+    Tensor<BDataType> b_g_k_n(
+        f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{}));

     auto f_host_c_tensor_descriptor = [](std::size_t G0_,
                                          std::size_t G1_,
@@ -131,10 +138,10 @@ int main(int argc, char* argv[])
             std::vector<std::size_t>({stride_G0_, stride_G1_, stride_M_, stride_N_}));
     };

-    Tensor<CDataType> c_g0_g1_m_n_host_result(
+    Tensor<EDataType> c_g0_g1_m_n_host_result(
         f_host_c_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
-    Tensor<CDataType> c_g0_g1_m_n_device_result(
+    Tensor<EDataType> c_g0_g1_m_n_device_result(
         f_host_c_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));

     std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
@@ -156,32 +163,34 @@ int main(int argc, char* argv[])
     DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
     DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
-    DeviceMem c_device_buf(sizeof(CDataType) * c_g0_g1_m_n_device_result.mDesc.GetElementSpace());
+    DeviceMem c_device_buf(sizeof(EDataType) * c_g0_g1_m_n_device_result.mDesc.GetElementSpace());

     a_device_buf.ToDevice(a_g_m_k.mData.data());
     b_device_buf.ToDevice(b_g_k_n.mData.data());

     auto a_element_op   = AElementOp{};
     auto b_element_op   = BElementOp{};
-    auto c_element_op   = CElementOp{};
+    auto cde_element_op = CDEElementOp{};

     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();

     // do GEMM
     auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                       static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
-                                      static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
+                                      static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()),
                                       M,
                                       N,
                                       K,
                                       stride_A,
                                       stride_B,
+                                      batch_stride_A,
+                                      batch_stride_B,
                                       batched_gemm_c_permute_desc,
+                                      batch_count,
                                       a_element_op,
                                       b_element_op,
-                                      c_element_op,
-                                      batch_count);
+                                      cde_element_op);

     if(!gemm.IsSupportedArgument(argument))
     {
@@ -195,7 +204,7 @@ int main(int argc, char* argv[])
     std::size_t flop = std::size_t(2) * batch_count * M * N * K;
     std::size_t num_btype = sizeof(ADataType) * batch_count * M * K +
                             sizeof(BDataType) * batch_count * K * N +
-                            sizeof(CDataType) * batch_count * M * N;
+                            sizeof(EDataType) * batch_count * M * N;

     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -213,11 +222,11 @@ int main(int argc, char* argv[])
         auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
         auto ref_invoker      = ref_batched_gemm.MakeInvoker();

-        Tensor<CDataType> c_g_m_n_host_result = HostTensorDescriptor(
+        Tensor<EDataType> c_g_m_n_host_result = HostTensorDescriptor(
             std::vector<std::size_t>({batch_count, M, N}), std::vector<std::size_t>({M * N, N, 1}));

         auto ref_argument = ref_batched_gemm.MakeArgument(
-            a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);
+            a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, cde_element_op);

         ref_invoker.Run(ref_argument);
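For reference, the updated `f_host_tensor_descriptor` lambda above maps a logical `(batch, row, col)` index through strides `{batch_stride, stride, 1}` (row-major) or `{batch_stride, 1, stride}` (column-major). The same index arithmetic as a standalone sketch (hypothetical helper names, mirroring the lambda):

```cpp
#include <cstddef>

// Row-major: HostTensorDescriptor({batch, row, col}, {batch_stride, stride, 1}).
std::size_t flat_offset_row_major(std::size_t g, std::size_t row, std::size_t col,
                                  std::size_t stride, std::size_t batch_stride)
{
    return g * batch_stride + row * stride + col;
}

// Column-major: HostTensorDescriptor({batch, row, col}, {batch_stride, 1, stride}).
std::size_t flat_offset_col_major(std::size_t g, std::size_t row, std::size_t col,
                                  std::size_t stride, std::size_t batch_stride)
{
    return g * batch_stride + row + col * stride;
}
```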
...
@@ -96,24 +96,27 @@ int main(int argc, char* argv[])
                                         std::size_t row,
                                         std::size_t col,
                                         std::size_t stride,
+                                        std::size_t batch_stride,
                                         auto layout) {
         if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
         {
             return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
-                                        std::vector<std::size_t>({row * stride, stride, 1}));
+                                        std::vector<std::size_t>({batch_stride, stride, 1}));
         }
         else
         {
             return HostTensorDescriptor(std::vector<std::size_t>({batch_count_, row, col}),
-                                        std::vector<std::size_t>({col * stride, 1, stride}));
+                                        std::vector<std::size_t>({batch_stride, 1, stride}));
         }
     };

-    Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(batch_count, M, K, stride_A, ALayout{}));
-    Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(batch_count, K, N, stride_B, BLayout{}));
+    Tensor<ADataType> a_g_m_k(
+        f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{}));
+    Tensor<BDataType> b_g_k_n(
+        f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{}));
     Tensor<EDataType> e_g_m_n_device_result(
-        f_host_tensor_descriptor(batch_count, M, N, stride_C, ELayout{}));
+        f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{}));

     std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
     std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
@@ -198,7 +201,7 @@ int main(int argc, char* argv[])
         auto ref_invoker = ref_batched_gemm.MakeInvoker();

         Tensor<EDataType> e_g_m_n_host_result(
-            f_host_tensor_descriptor(batch_count, M, N, stride_C, ELayout{}));
+            f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{}));

         auto ref_argument = ref_batched_gemm.MakeArgument(
             a_g_m_k, b_g_k_n, e_g_m_n_host_result, a_element_op, b_element_op, cde_element_op);
...
@@ -14,9 +14,15 @@ struct BatchedGemmCPermuteDesc
     ck::index_t stride_G0_, stride_G1_, stride_M_, stride_N_;
 };

-template <typename AElementwiseOperation,
+template <typename ALayout,
+          typename BLayout,
+          typename DELayout,
+          typename ADataType,
+          typename BDataType,
+          typename EDataType,
+          typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CElementwiseOperation>
+          typename CDEElementwiseOperation>
 struct DeviceBatchedGemmCPermute : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument>
@@ -28,20 +34,36 @@ struct DeviceBatchedGemmCPermute : public BaseOperator
                         index_t K,
                         index_t stride_A,
                         index_t stride_B,
+                        index_t batch_stride_A,
+                        index_t batch_stride_B,
                         BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
+                        index_t BatchCount,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        ck::index_t BatchCount) = 0;
+                        CDEElementwiseOperation c_element_op) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };

-template <typename AElementwiseOperation,
+template <typename ALayout,
+          typename BLayout,
+          typename DELayout,
+          typename ADataType,
+          typename BDataType,
+          typename EDataType,
+          typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CElementwiseOperation>
-using DeviceBatchedGemmCPermutePtr = std::unique_ptr<
-    DeviceBatchedGemmCPermute<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
+          typename CDEElementwiseOperation>
+using DeviceBatchedGemmCPermutePtr =
+    std::unique_ptr<DeviceBatchedGemmCPermute<ALayout,
+                                              BLayout,
+                                              DELayout,
+                                              ADataType,
+                                              BDataType,
+                                              EDataType,
+                                              AElementwiseOperation,
+                                              BElementwiseOperation,
+                                              CDEElementwiseOperation>>;

 } // namespace device
 } // namespace tensor_operation
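The reworked pure-virtual interface fixes the new argument order: the two batch strides follow the leading-dimension strides, and `BatchCount` now precedes the element-wise operators. A hedged call-site sketch (the buffer-pointer parameters sit above this hunk and are not shown in the diff, so their types are assumed; all variables are presumed defined by the caller):

```cpp
// Hypothetical call through the type-erased pointer interface.
auto argument_ptr =
    gemm_ptr->MakeArgumentPointer(p_a, // device pointer to A
                                  p_b, // device pointer to B
                                  p_e, // device pointer to the permuted output
                                  M,
                                  N,
                                  K,
                                  stride_A,
                                  stride_B,
                                  batch_stride_A,              // new in this commit
                                  batch_stride_B,              // new in this commit
                                  batched_gemm_c_permute_desc, // G0/G1/M/N plus output strides
                                  batch_count,                 // moved before the element ops
                                  a_element_op,
                                  b_element_op,
                                  cde_element_op);
```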
...
@@ -8,6 +8,7 @@
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
 #include "ck/device_utility/device_prop.hpp"
@@ -45,12 +46,12 @@ namespace device {
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
-          typename AGridDesc_K0_M_K1,
-          typename BGridDesc_K0_N_K1,
+          typename AGridDesc_AK0_M_AK1,
+          typename BGridDesc_BK0_N_BK1,
           typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CElementwiseOperation,
+          typename CDEElementwiseOperation,
           typename ComputePtrOffsetOfBatch,
           typename Block2CTileMap,
           bool HasMainKBlockLoop>
@@ -60,15 +61,15 @@ __global__ void
 #endif
         kernel_batched_gemm_c_permute_xdl(const FloatAB* __restrict__ p_a_grid,
                                           const FloatAB* __restrict__ p_b_grid,
-                                          FloatC* __restrict__ p_c_grid,
+                                          FloatC* __restrict__ p_e_grid,
                                           const index_t batch_count,
-                                          const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
-                                          const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
+                                          const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
+                                          const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
                                           const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                                               c_grid_desc_mblock_mperblock_nblock_nperblock,
                                           const AElementwiseOperation a_element_op,
                                           const BElementwiseOperation b_element_op,
-                                          const CElementwiseOperation c_element_op,
+                                          const CDEElementwiseOperation cde_element_op,
                                           const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
                                           const Block2CTileMap block_2_ctile_map)
 {
@@ -90,11 +91,11 @@ __global__ void
         p_a_grid + a_batch_offset,
         p_b_grid + b_batch_offset,
         ck::Tuple<>{},
-        p_c_grid + c_batch_offset,
+        p_e_grid + c_batch_offset,
         p_shared,
         a_element_op,
         b_element_op,
-        c_element_op,
+        cde_element_op,
         a_grid_desc_k0_m_k1,
         b_grid_desc_k0_n_k1,
         ck::StaticallyIndexedArray<
@@ -105,14 +106,14 @@ __global__ void
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
-    ignore = p_c_grid;
+    ignore = p_e_grid;
     ignore = batch_count;
     ignore = a_grid_desc_k0_m_k1;
     ignore = b_grid_desc_k0_n_k1;
     ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = a_element_op;
     ignore = b_element_op;
-    ignore = c_element_op;
+    ignore = cde_element_op;
     ignore = compute_ptr_offset_of_batch;
     ignore = block_2_ctile_map;
 #endif
@@ -120,48 +121,60 @@
 template <typename ALayout,
           typename BLayout,
+          typename DELayout,
           typename ADataType,
           typename BDataType,
-          typename CDataType,
-          typename AccDataType,
+          typename GemmAccDataType,
+          typename CShuffleDataType,
+          typename DsDataType,
+          typename EDataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CElementwiseOperation,
+          typename CDEElementwiseOperation,
           GemmSpecialization GemmSpec,
-          ck::index_t NumPrefetch,
-          ck::index_t BlockSize,
-          ck::index_t MPerBlock,
-          ck::index_t NPerBlock,
-          ck::index_t KPerBlock,
-          ck::index_t AK1,
-          ck::index_t BK1,
-          ck::index_t MPerXDL,
-          ck::index_t NPerXDL,
-          ck::index_t MXdlPerWave,
-          ck::index_t NXdlPerWave,
-          typename ABlockTransferThreadClusterLengths_K0_M_K1,
+          index_t NumGemmKPrefetchStage,
+          index_t BlockSize,
+          index_t MPerBlock,
+          index_t NPerBlock,
+          index_t KPerBlock,
+          index_t AK1,
+          index_t BK1,
+          index_t MPerXDL,
+          index_t NPerXDL,
+          index_t MXdlPerWave,
+          index_t NXdlPerWave,
+          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
-          ck::index_t ABlockTransferSrcVectorDim,
-          ck::index_t ABlockTransferSrcScalarPerVector,
-          ck::index_t ABlockTransferDstScalarPerVector_K1,
-          bool ABlockLdsAddExtraM,
-          typename BBlockTransferThreadClusterLengths_K0_N_K1,
+          index_t ABlockTransferSrcVectorDim,
+          index_t ABlockTransferSrcScalarPerVector,
+          index_t ABlockTransferDstScalarPerVector_AK1,
+          bool ABlockLdsExtraM,
+          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
           typename BBlockTransferThreadClusterArrangeOrder,
           typename BBlockTransferSrcAccessOrder,
-          ck::index_t BBlockTransferSrcVectorDim,
-          ck::index_t BBlockTransferSrcScalarPerVector,
-          ck::index_t BBlockTransferDstScalarPerVector_K1,
-          bool BBlockLdsAddExtraN,
+          index_t BBlockTransferSrcVectorDim,
+          index_t BBlockTransferSrcScalarPerVector,
+          index_t BBlockTransferDstScalarPerVector_BK1,
+          bool BBlockLdsExtraN,
           index_t CShuffleMXdlPerWavePerShuffle,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
-struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementwiseOperation,
-                                                                       BElementwiseOperation,
-                                                                       CElementwiseOperation>
+struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<ALayout,
+                                                                       BLayout,
+                                                                       DELayout,
+                                                                       ADataType,
+                                                                       BDataType,
+                                                                       EDataType,
+                                                                       AElementwiseOperation,
+                                                                       BElementwiseOperation,
+                                                                       CDEElementwiseOperation>
 {
+    using DeviceOp = DeviceBatchedGemmCPermuteXdl;
+
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
@@ -373,7 +386,7 @@ struct DeviceBatchedGemmCPermuteXdl
     }

     static auto
-    MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
+    MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
     {
         const auto c_grid_desc_mraw_nraw = [&]() {
             return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
@@ -489,9 +502,9 @@ struct DeviceBatchedGemmCPermuteXdl
         }
     }

-    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
-    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
-    using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1));
+    using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
+    using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
+    using EGridDesc_M_N       = decltype(MakeEGridDescriptor_M_N(1, 1, 1, 1));
     using EGridDesc_G0_G1_M_N = decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1));

     struct ComputePtrOffsetOfStridedBatch
@@ -531,18 +544,18 @@ struct DeviceBatchedGemmCPermuteXdl
     using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
-        AccDataType,
-        CDataType, // CShuffleDataType,
-        ck::Tuple<>, // DsDataType,
-        CDataType, // EDataType,
+        GemmAccDataType,
+        CShuffleDataType,
+        DsDataType,
+        EDataType,
         AElementwiseOperation,
         BElementwiseOperation,
-        CElementwiseOperation,
+        CDEElementwiseOperation,
         InMemoryDataOperationEnum::Set,
-        AGridDesc_K0_M_K1,
-        BGridDesc_K0_N_K1,
-        CGridDesc_M_N,
-        NumPrefetch,
+        AGridDesc_AK0_M_AK1,
+        BGridDesc_BK0_N_BK1,
+        EGridDesc_M_N,
+        NumGemmKPrefetchStage,
         BlockSize,
         MPerBlock,
         NPerBlock,
@@ -553,22 +566,22 @@ struct DeviceBatchedGemmCPermuteXdl
         NPerXDL,
         MXdlPerWave,
         NXdlPerWave,
-        ABlockTransferThreadClusterLengths_K0_M_K1,
+        ABlockTransferThreadClusterLengths_AK0_M_AK1,
         ABlockTransferThreadClusterArrangeOrder,
         ABlockTransferSrcAccessOrder,
         ABlockTransferSrcVectorDim,
         ABlockTransferSrcScalarPerVector,
-        ABlockTransferDstScalarPerVector_K1,
-        false, // AThreadTransferSrcResetCoordinateAfterRun,
-        ABlockLdsAddExtraM,
-        BBlockTransferThreadClusterLengths_K0_N_K1,
+        ABlockTransferDstScalarPerVector_AK1,
+        false,
+        ABlockLdsExtraM,
+        BBlockTransferThreadClusterLengths_BK0_N_BK1,
         BBlockTransferThreadClusterArrangeOrder,
         BBlockTransferSrcAccessOrder,
         BBlockTransferSrcVectorDim,
         BBlockTransferSrcScalarPerVector,
-        BBlockTransferDstScalarPerVector_K1,
-        false, // BThreadTransferSrcResetCoordinateAfterRun,
-        BBlockLdsAddExtraN,
+        BBlockTransferDstScalarPerVector_BK1,
+        false,
+        BBlockLdsExtraN,
         CShuffleMXdlPerWavePerShuffle,
         CShuffleNXdlPerWavePerShuffle,
         CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -576,7 +589,7 @@ struct DeviceBatchedGemmCPermuteXdl
         LoopSched>;

     using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
-        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
+        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
     using Block2CTileMap = typename GridwiseGemm::DefaultBlock2ETileMap;

     // Argument
@@ -584,26 +597,28 @@ struct DeviceBatchedGemmCPermuteXdl
     {
         Argument(const ADataType* p_a_grid,
                  const BDataType* p_b_grid,
-                 CDataType* p_c_grid,
+                 EDataType* p_e_grid,
                  index_t M,
                  index_t N,
                  index_t K,
                  index_t stride_A,
                  index_t stride_B,
+                 index_t batch_stride_A,
+                 index_t batch_stride_B,
                  BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
+                 index_t BatchCount,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
-                 CElementwiseOperation c_element_op,
-                 index_t BatchCount)
+                 CDEElementwiseOperation cde_element_op)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
-              p_c_grid_{p_c_grid},
+              p_e_grid_{p_e_grid},
               BatchCount_(BatchCount),
-              a_grid_desc_k0_m_k1_{
+              a_grid_desc_ak0_m_ak1_{
                   DeviceBatchedGemmCPermuteXdl::MakeAGridDescriptor_AK0_M_AK1(M, K, stride_A)},
-              b_grid_desc_k0_n_k1_{
+              b_grid_desc_bk0_n_bk1_{
                   DeviceBatchedGemmCPermuteXdl::MakeBGridDescriptor_BK0_N_BK1(K, N, stride_B)},
-              c_grid_desc_m_n_{DeviceBatchedGemmCPermuteXdl::MakeCGridDescriptor_M_N(
+              e_grid_desc_m_n_{DeviceBatchedGemmCPermuteXdl::MakeEGridDescriptor_M_N(
                   batched_gemm_c_permute_desc.M_,
                   batched_gemm_c_permute_desc.N_,
                   batched_gemm_c_permute_desc.stride_M_,
@@ -618,42 +633,39 @@ struct DeviceBatchedGemmCPermuteXdl
                   batched_gemm_c_permute_desc.stride_M_,
                   batched_gemm_c_permute_desc.stride_N_)},
               c_grid_desc_mblock_mperblock_nblock_nperblock{},
-              compute_ptr_offset_of_batch_{
-                  type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
-                  type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
-                  e_grid_desc_g0_g1_m_n_},
-              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(c_grid_desc_m_n_)},
+              compute_ptr_offset_of_batch_{batch_stride_A, batch_stride_B, e_grid_desc_g0_g1_m_n_},
+              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
-              c_element_op_{c_element_op}
+              cde_element_op_{cde_element_op}
         {
-            if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
-                                           b_grid_desc_k0_n_k1_,
-                                           c_grid_desc_m_n_,
+            if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
+                                           b_grid_desc_bk0_n_bk1_,
+                                           e_grid_desc_m_n_,
                                            block_2_ctile_map_))
             {
                 c_grid_desc_mblock_mperblock_nblock_nperblock =
                     GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                        c_grid_desc_m_n_);
+                        e_grid_desc_m_n_);
             }
         }

         // private:
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
-        CDataType* p_c_grid_;
+        EDataType* p_e_grid_;
         index_t BatchCount_;
-        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
-        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
-        CGridDesc_M_N c_grid_desc_m_n_;
+        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
+        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
+        EGridDesc_M_N e_grid_desc_m_n_;
         EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
         CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock;
         ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
         Block2CTileMap block_2_ctile_map_;
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
-        CElementwiseOperation c_element_op_;
+        CDEElementwiseOperation cde_element_op_;
     };

     // Invoker
@@ -664,21 +676,23 @@ struct DeviceBatchedGemmCPermuteXdl
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
             {
-                std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
-                          << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
-                          << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
-
-                std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
-                          << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
-                          << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
-
-                std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
-                          << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+                std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
+                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
+                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
+                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
+
+                std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
+                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
+                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
+                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
+
+                std::cout << "arg.e_grid_desc_m_n_{" << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
+                          << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
             }

-            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
-                                            arg.b_grid_desc_k0_n_k1_,
-                                            arg.c_grid_desc_m_n_,
+            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
+                                            arg.b_grid_desc_bk0_n_bk1_,
+                                            arg.e_grid_desc_m_n_,
                                             arg.block_2_ctile_map_))
             {
                 throw std::runtime_error(
@@ -687,10 +701,10 @@ struct DeviceBatchedGemmCPermuteXdl
             }

             const index_t grid_size =
-                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
+                arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.BatchCount_;

             const auto K =
-                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
+                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

             float ave_time = 0;
@@ -698,13 +712,13 @@ struct DeviceBatchedGemmCPermuteXdl
                 const auto kernel = kernel_batched_gemm_c_permute_xdl<
                     GridwiseGemm,
                     ADataType, // TODO: distinguish A/B datatype
-                    CDataType,
-                    remove_reference_t<DeviceBatchedGemmCPermuteXdl::AGridDesc_K0_M_K1>,
-                    remove_reference_t<DeviceBatchedGemmCPermuteXdl::BGridDesc_K0_N_K1>,
+                    EDataType,
+                    AGridDesc_AK0_M_AK1,
+                    BGridDesc_BK0_N_BK1,
                     typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                     AElementwiseOperation,
                     BElementwiseOperation,
-                    CElementwiseOperation,
+                    CDEElementwiseOperation,
                     ComputePtrOffsetOfStridedBatch,
                     remove_reference_t<Block2CTileMap>,
                     has_main_k_block_loop_>;
@@ -716,14 +730,14 @@ struct DeviceBatchedGemmCPermuteXdl
                     0,
                     arg.p_a_grid_,
                     arg.p_b_grid_,
-                    arg.p_c_grid_,
+                    arg.p_e_grid_,
                     arg.BatchCount_,
-                    arg.a_grid_desc_k0_m_k1_,
-                    arg.b_grid_desc_k0_n_k1_,
+                    arg.a_grid_desc_ak0_m_ak1_,
+                    arg.b_grid_desc_bk0_n_bk1_,
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
                     arg.a_element_op_,
                     arg.b_element_op_,
-                    arg.c_element_op_,
+                    arg.cde_element_op_,
                     arg.compute_ptr_offset_of_batch_,
                     arg.block_2_ctile_map_);
             };
@@ -756,31 +770,27 @@ struct DeviceBatchedGemmCPermuteXdl
     static bool IsSupportedArgument(const Argument& arg)
     {
-        return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
-                                           arg.b_grid_desc_k0_n_k1_,
-                                           arg.c_grid_desc_m_n_,
+        return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
+                                           arg.b_grid_desc_bk0_n_bk1_,
+                                           arg.e_grid_desc_m_n_,
                                            arg.block_2_ctile_map_);
     }

-    // polymorphic
-    bool IsSupportedArgument(const BaseArgument* p_arg) override
-    {
-        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
-    }
-
     static auto MakeArgument(const ADataType* p_a,
                              const BDataType* p_b,
-                             CDataType* p_c,
+                             EDataType* p_c,
                              index_t M,
                              index_t N,
                              index_t K,
                              index_t stride_A,
                              index_t stride_B,
+                             index_t batch_stride_A,
+                             index_t batch_stride_B,
                              BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
+                             index_t BatchCount,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
-                             CElementwiseOperation c_element_op,
-                             index_t BatchCount)
+                             CDEElementwiseOperation cde_element_op)
     {
         return Argument{p_a,
                         p_b,
@@ -790,11 +800,13 @@ struct DeviceBatchedGemmCPermuteXdl
                         K,
                         stride_A,
                         stride_B,
+                        batch_stride_A,
+                        batch_stride_B,
                         batched_gemm_c_permute_desc,
+                        BatchCount,
                         a_element_op,
                         b_element_op,
-                        c_element_op,
-                        BatchCount};
+                        cde_element_op};
     }

     static auto MakeInvoker() { return Invoker{}; }
@@ -809,25 +821,29 @@ struct DeviceBatchedGemmCPermuteXdl
                         index_t K,
                         index_t stride_A,
                         index_t stride_B,
+                        index_t batch_stride_A,
+                        index_t batch_stride_B,
                         BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
+                        index_t BatchCount,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        index_t BatchCount) override
+                        CDEElementwiseOperation cde_element_op) override
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
-                                          static_cast<CDataType*>(p_c),
+                                          static_cast<EDataType*>(p_c),
                                           M,
                                           N,
                                           K,
                                           stride_A,
                                           stride_B,
+                                          batch_stride_A,
+                                          batch_stride_B,
                                           batched_gemm_c_permute_desc,
+                                          BatchCount,
                                           a_element_op,
                                           b_element_op,
-                                          c_element_op,
-                                          BatchCount);
+                                          cde_element_op);
     }

     // polymorphic
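The functional core of the change sits in the `Argument` constructor above: `compute_ptr_offset_of_batch_` is now built from the caller-supplied `batch_stride_A`/`batch_stride_B` rather than from `GetElementSpaceSize()` of the A/B descriptors. A rough, self-contained sketch of what such an offset functor computes (member and method names follow the usual CK pattern but are an assumption; only the constructor arguments appear in this diff):

```cpp
#include <cstdint>

struct ComputePtrOffsetOfStridedBatchSketch
{
    std::int32_t batch_stride_a_; // ck::index_t in the real code
    std::int32_t batch_stride_b_;

    // Per-batch element offsets, consumed by the kernel as
    // p_a_grid + a_batch_offset and p_b_grid + b_batch_offset.
    std::int64_t GetAPtrOffset(std::int32_t g) const
    {
        return std::int64_t{g} * batch_stride_a_;
    }
    std::int64_t GetBPtrOffset(std::int32_t g) const
    {
        return std::int64_t{g} * batch_stride_b_;
    }

    // The E-side offset is still derived from the permuted output
    // descriptor (EGridDesc_G0_G1_M_N), not from a single scalar stride.
};
```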
...