Commit ab04f22f authored by Jing Zhang

add c_permute

parent ef18bd98
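This commit renames DeviceBatchedGemmCPermutate(Xdl) to DeviceBatchedGemmCPermute(Xdl) and reworks the output side of the batched GEMM: the flat batch stride for C is replaced by a four-dimensional [G0, G1, M, N] descriptor, so a batch count of G0 * G1 can be written out in a permuted layout such as [G0, M, G1, N]. A minimal sketch of the index math the new descriptor implements (plain C++ for illustration; the helper name is hypothetical and not part of the commit):

```cpp
#include <cstdint>

// Offset of output element (g, m, n) after the C-permute. The kernel does this
// through a CK tensor descriptor (see MakeEGridDescriptor_G0_G1_M_N and
// GetCPtrOffset in the diff below); this standalone helper only mirrors the math.
std::int64_t c_permute_offset(int g, int m, int n, int G1,
                              std::int64_t stride_G0, std::int64_t stride_G1,
                              std::int64_t stride_M, std::int64_t stride_N)
{
    const int g0 = g / G1; // split the flat batch index into (g0, g1)
    const int g1 = g % G1;
    return g0 * stride_G0 + g1 * stride_G1 + m * stride_M + n * stride_N;
}
```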
@@ -6,16 +6,14 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute_xdl.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/host_tensor/device_memory.hpp"
 #include "ck/library/host_tensor/host_tensor.hpp"
 #include "ck/library/host_tensor/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -44,7 +42,7 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough;
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

 // clang-format off
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmCPermutationXdl
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmCPermuteXdl
 //######| ALayout| BLayout| AData| BData| CData| AccData| A| B| C| GEMM| Num| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
 //######| | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 //######| | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
@@ -52,13 +50,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmCPermu
 < Row, Col, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
 // clang-format on
-using ReferenceBatchedGemmCPermutationInstance =
-    ck::tensor_operation::host::ReferenceBatchedGemmCPermutation<ADataType,
-                                                                 BDataType,
-                                                                 CDataType,
-                                                                 AElementOp,
-                                                                 BElementOp,
-                                                                 CElementOp>;
+using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
+    ReferenceBatchedGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
 int main(int argc, char* argv[])
 {
@@ -66,26 +59,27 @@ int main(int argc, char* argv[])
     int init_method = 1;
     bool time_kernel = false;
-    const int M0 = rand() % 4 + 1;
-    const int M1 = 256;
-    const int N0 = rand() % 4 + 1;
-    const int N1 = 256;
-    const int M = M0 * N1;
-    const int N = N0 * N1;
-    const int K = 128 * (rand() % 4 + 1);
+    // const int M = 88;
+    // const int N = 64;
+    // const int K = 88;
+    const int M = 256;
+    const int N = 128;
+    const int K = 64;

     const int stride_A = K;
     const int stride_B = K;

-    // output layout [M0, N0, M1, N1]
-    const int stride_M0 = N1 * M1 * N0;
-    const int stride_M1 = N1;
-    const int stride_N0 = N1 * M1;
-    const int stride_N1 = 1;
-
-    int batch_count = rand() % 16 + 1;
+    const int G0 = 1024;
+    const int G1 = 10;
+
+    const int batch_count = G0 * G1;
+
+    // output layout - [G0, M, G1, N]
+    const int stride_B0 = M * G1 * N;
+    const int stride_B1 = N;
+    const int stride_M = G1 * N;
+    const int stride_N = 1;
     if(argc == 4)
     {
@@ -102,8 +96,8 @@ int main(int argc, char* argv[])
     }

     // GEMM shape
-    ck::tensor_operation::device::GemmTransposeDesc gemm_transpose_desc{
-        M, N, K, stride_A, stride_B, M0, M1, N0, N1, stride_M0, stride_M1, stride_N0, stride_N1};
+    ck::tensor_operation::device::BatchedGemmCPermuteDesc batched_gemm_c_permute_desc{
+        G0, G1, M, N, stride_B0, stride_B1, stride_M, stride_N};

     auto f_host_tensor_descriptor = [](std::size_t batch_count_,
                                        std::size_t row,
@@ -125,30 +119,28 @@ int main(int argc, char* argv[])
     Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(batch_count, M, K, stride_A, ALayout{}));
     Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(batch_count, K, N, stride_B, BLayout{}));

-    auto f_host_c_tensor_descriptor = [](std::size_t batch_count_,
-                                         std::size_t M0_,
-                                         std::size_t M1_,
-                                         std::size_t N0_,
-                                         std::size_t N1_,
-                                         std::size_t StrideM0_,
-                                         std::size_t StrideM1_,
-                                         std::size_t StrideN0_,
-                                         std::size_t StrideN1_) {
-        return HostTensorDescriptor(
-            std::vector<std::size_t>({batch_count_, M0_, M1_, N0_, N1_}),
-            std::vector<std::size_t>(
-                {M0_ * M1_ * N0_ * N1_, StrideM0_, StrideM1_, StrideN0_, StrideN1_}));
-    };
+    auto f_host_c_tensor_descriptor = [](std::size_t B0_,
+                                         std::size_t B1_,
+                                         std::size_t M_,
+                                         std::size_t N_,
+                                         std::size_t stride_B0_,
+                                         std::size_t stride_B1_,
+                                         std::size_t stride_M_,
+                                         std::size_t stride_N_) {
+        return HostTensorDescriptor(
+            std::vector<std::size_t>({B0_, B1_, M_, N_}),
+            std::vector<std::size_t>({stride_B0_, stride_B1_, stride_M_, stride_N_}));
+    };

-    Tensor<CDataType> c_g_m0_m1_n0_n1_host_result(f_host_c_tensor_descriptor(
-        batch_count, M0, M1, N0, N1, stride_M0, stride_M1, stride_N0, stride_N1));
-    Tensor<CDataType> c_g_m0_m1_n0_n1_device_result(f_host_c_tensor_descriptor(
-        batch_count, M0, M1, N0, N1, stride_M0, stride_M1, stride_N0, stride_N1));
+    Tensor<CDataType> c_g0_g1_m_n_host_result(
+        f_host_c_tensor_descriptor(G0, G1, M, N, stride_B0, stride_B1, stride_M, stride_N));
+    Tensor<CDataType> c_g0_g1_m_n_device_result(
+        f_host_c_tensor_descriptor(G0, G1, M, N, stride_B0, stride_B1, stride_M, stride_N));

     std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
     std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
-    std::cout << "c_g_m_n: " << c_g_m0_m1_n0_n1_host_result.mDesc << std::endl;
+    std::cout << "c_g0_g1_m_n: " << c_g0_g1_m_n_host_result.mDesc << std::endl;
     switch(init_method)
     {
@@ -165,8 +157,7 @@ int main(int argc, char* argv[])

     DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
     DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
-    DeviceMem c_device_buf(sizeof(CDataType) *
-                           c_g_m0_m1_n0_n1_device_result.mDesc.GetElementSpace());
+    DeviceMem c_device_buf(sizeof(CDataType) * c_g0_g1_m_n_device_result.mDesc.GetElementSpace());

     a_device_buf.ToDevice(a_g_m_k.mData.data());
     b_device_buf.ToDevice(b_g_k_n.mData.data());
@@ -182,7 +173,12 @@ int main(int argc, char* argv[])
     auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                       static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                       static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                      gemm_transpose_desc,
+                                      M,
+                                      N,
+                                      K,
+                                      stride_A,
+                                      stride_B,
+                                      batched_gemm_c_permute_desc,
                                       a_element_op,
                                       b_element_op,
                                       c_element_op,
@@ -213,22 +209,36 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
-        c_device_buf.FromDevice(c_g_m0_m1_n0_n1_device_result.mData.data());
+        c_device_buf.FromDevice(c_g0_g1_m_n_device_result.mData.data());

-        auto ref_batched_gemm = ReferenceBatchedGemmCPermutationInstance{};
+        auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
         auto ref_invoker = ref_batched_gemm.MakeInvoker();

-        auto ref_argument = ref_batched_gemm.MakeArgument(a_g_m_k,
-                                                          b_g_k_n,
-                                                          c_g_m0_m1_n0_n1_host_result,
-                                                          a_element_op,
-                                                          b_element_op,
-                                                          c_element_op);
+        Tensor<CDataType> c_g_m_n_host_result = HostTensorDescriptor(
+            std::vector<std::size_t>({batch_count, M, N}), std::vector<std::size_t>({M * N, N, 1}));
+
+        auto ref_argument = ref_batched_gemm.MakeArgument(
+            a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);

         ref_invoker.Run(ref_argument);

-        pass = ck::utils::check_err(c_g_m0_m1_n0_n1_host_result.mData,
-                                    c_g_m0_m1_n0_n1_device_result.mData,
+        for(int g0 = 0; g0 < G0; g0++)
+        {
+            for(int g1 = 0; g1 < G1; g1++)
+            {
+                for(int m = 0; m < M; m++)
+                {
+                    for(int n = 0; n < N; n++)
+                    {
+                        int g = g0 * G1 + g1;
+                        c_g0_g1_m_n_host_result(g0, g1, m, n) = c_g_m_n_host_result(g, m, n);
+                    }
+                }
+            }
+        }
+
+        pass = ck::utils::check_err(c_g0_g1_m_n_host_result.mData,
+                                    c_g0_g1_m_n_device_result.mData,
                                     "Error: Incorrect results c");
     }
......
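The stride choice in the example above is what realizes the permute: batch g = g0 * G1 + g1 of the plain [G, M, N] reference result is scattered so that memory order becomes [G0, M, G1, N]. A small standalone check of that arithmetic (hypothetical test code, not part of the commit):

```cpp
#include <cassert>

int main()
{
    const int G0 = 2, G1 = 3, M = 4, N = 5; // small sizes for illustration
    // Strides as chosen in the example for the [G0, M, G1, N] output layout.
    const int stride_B0 = M * G1 * N; // stride of g0
    const int stride_B1 = N;          // stride of g1
    const int stride_M  = G1 * N;     // stride of m
    const int stride_N  = 1;          // stride of n

    for(int g0 = 0; g0 < G0; ++g0)
        for(int g1 = 0; g1 < G1; ++g1)
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                {
                    const int offset =
                        g0 * stride_B0 + g1 * stride_B1 + m * stride_M + n * stride_N;
                    // Same element addressed as a row-major [G0][M][G1][N] array.
                    const int expected = ((g0 * M + m) * G1 + g1) * N + n;
                    assert(offset == expected);
                }
    return 0;
}
```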
@@ -10,29 +10,29 @@ namespace device {
 struct BatchedGemmCPermuteDesc
 {
-    ck::index_t B0_, B1_, M_, N_;
-    ck::index_t stride_B0_, stride_B1_, stride_M_, stride_N_;
+    ck::index_t G0_, G1_, M_, N_;
+    ck::index_t stride_G0_, stride_G1_, stride_M_, stride_N_;
 };

 template <typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation>
-struct DeviceBatchedGemmCPermutate : public BaseOperator
+struct DeviceBatchedGemmCPermute : public BaseOperator
 {
-    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
-                                                              const void* p_b,
-                                                              void* p_c,
-                                                              index_t M,
-                                                              index_t N,
-                                                              index_t K,
-                                                              index_t stride_A,
-                                                              index_t stride_B,
-                                                              BatchedGemmCPermuteDesc batched_gemm_c_permute_desc
-                                                              AElementwiseOperation a_element_op,
-                                                              BElementwiseOperation b_element_op,
-                                                              CElementwiseOperation c_element_op,
-                                                              ck::index_t BatchCount = 1) = 0;
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        void* p_c,
+                        index_t M,
+                        index_t N,
+                        index_t K,
+                        index_t stride_A,
+                        index_t stride_B,
+                        BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CElementwiseOperation c_element_op,
+                        ck::index_t BatchCount = 1) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
@@ -40,10 +40,8 @@ struct DeviceBatchedGemmCPermutate : public BaseOperator
 template <typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation>
-using DeviceBatchedGemmCPermutatePtr =
-    std::unique_ptr<DeviceBatchedGemmCPermutate<AElementwiseOperation,
-                                                BElementwiseOperation,
-                                                CElementwiseOperation>>;
+using DeviceBatchedGemmCPermutePtr = std::unique_ptr<
+    DeviceBatchedGemmCPermute<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;

 } // namespace device
 } // namespace tensor_operation
......
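The header above only declares the type-erased interface; a caller typically drives it as sketched below. This is a hedged example under the assumption that `op` holds a concrete instance (e.g. the DeviceBatchedGemmCPermuteXdl specialization from the next file); the driver function itself is illustrative and not part of the commit:

```cpp
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceOpPtr = ck::tensor_operation::device::
    DeviceBatchedGemmCPermutePtr<PassThrough, PassThrough, PassThrough>;

// Hypothetical driver built only on the virtual interface declared above.
float run_batched_gemm_c_permute(DeviceOpPtr& op,
                                 const void* p_a,
                                 const void* p_b,
                                 void* p_c,
                                 ck::index_t M,
                                 ck::index_t N,
                                 ck::index_t K,
                                 ck::index_t stride_A,
                                 ck::index_t stride_B,
                                 ck::tensor_operation::device::BatchedGemmCPermuteDesc desc)
{
    auto argument = op->MakeArgumentPointer(p_a, p_b, p_c, M, N, K, stride_A, stride_B, desc,
                                            PassThrough{}, PassThrough{}, PassThrough{},
                                            desc.G0_ * desc.G1_); // BatchCount = G0 * G1
    auto invoker  = op->MakeInvokerPointer();
    return invoker->Run(argument.get(), StreamConfig{nullptr, true}); // kernel time in ms
}
```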
@@ -13,8 +13,6 @@
 #include "ck/device_utility/device_prop.hpp"
 #include "ck/device_utility/kernel_launch.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {
@@ -41,7 +39,7 @@ namespace device {
  *
  * \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes.
  * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
- * realize BatchedGemmCPermutate and GroupedGemm (and the corresponding GEMM fusion).
+ * realize BatchedGemmCPermute and GroupedGemm (and the corresponding GEMM fusion).
  *
  */
 template <typename GridwiseGemm,
@@ -160,10 +158,9 @@ template <typename ALayout,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
-struct DeviceBatchedGemmCPermutateXdl
-    : public DeviceBatchedGemmCPermutate<AElementwiseOperation,
-                                         BElementwiseOperation,
-                                         CElementwiseOperation>
+struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementwiseOperation,
+                                                                       BElementwiseOperation,
+                                                                       CElementwiseOperation>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -247,14 +244,10 @@ struct DeviceBatchedGemmCPermutateXdl
         }
     }

-    static auto MakeCGridDescriptor_M_N(index_t M,
-                                        index_t N,
-                                        index_t stride_M,
-                                        index_t stride_N)
+    static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t stride_M, index_t stride_N)
     {
         const auto c_grid_desc_m_n = [&]() {
-            return make_naive_tensor_descriptor(
-                make_tuple(M, N), make_tuple(stride_M, stride_N));
+            return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(stride_M, stride_N));
         }();

         if constexpr(GemmSpec == GemmSpecialization::MNPadding)
@@ -279,16 +272,53 @@ struct DeviceBatchedGemmCPermutateXdl
         }
     }
-    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
-    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
-    using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1, 1, 1, 1));
+    static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0,
+                                              index_t G1,
+                                              index_t M,
+                                              index_t N,
+                                              index_t stride_G0,
+                                              index_t stride_G1,
+                                              index_t stride_M,
+                                              index_t stride_N)
+    {
+        const auto e_grid_desc_g0_g1_m_n = [&]() {
+            return make_naive_tensor_descriptor(
+                make_tuple(G0, G1, M, N), make_tuple(stride_G0, stride_G1, stride_M, stride_N));
+        }();
+
+        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
+        {
+            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
+            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
+
+            return transform_tensor_descriptor(
+                e_grid_desc_g0_g1_m_n,
+                make_tuple(make_pass_through_transform(G0),
+                           make_pass_through_transform(G1),
+                           make_right_pad_transform(M, PadM),
+                           make_right_pad_transform(N, PadN)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+        }
+        else
+        {
+            return e_grid_desc_g0_g1_m_n;
+        }
+    }
+
+    using AGridDesc_K0_M_K1   = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
+    using BGridDesc_K0_N_K1   = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
+    using CGridDesc_M_N       = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1));
+    using EGridDesc_G0_G1_M_N = decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1));
     struct ComputePtrOffsetOfStridedBatch
     {
         ComputePtrOffsetOfStridedBatch(index_t Batchstride_A,
                                        index_t Batchstride_B,
-                                       index_t BatchStrideC)
-            : Batchstride_A_(Batchstride_A), Batchstride_B_(Batchstride_B), BatchStrideC_(BatchStrideC)
+                                       EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n)
+            : Batchstride_A_(Batchstride_A),
+              Batchstride_B_(Batchstride_B),
+              e_grid_desc_g0_g1_m_n_(e_grid_desc_g0_g1_m_n)
         {
         }

@@ -304,13 +334,16 @@ struct DeviceBatchedGemmCPermutateXdl
         __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
         {
-            return g_idx * static_cast<long_index_t>(BatchStrideC_);
+            const index_t G1 = e_grid_desc_g0_g1_m_n_.GetLength(I1);
+            index_t b0       = g_idx / G1;
+            index_t b1       = g_idx % G1;
+            return e_grid_desc_g0_g1_m_n_.CalculateOffset(make_multi_index(b0, b1, 0, 0));
         }

         private:
         index_t Batchstride_A_;
         index_t Batchstride_B_;
-        index_t BatchStrideC_;
+        EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
     };

     using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
@@ -383,20 +416,29 @@ struct DeviceBatchedGemmCPermutateXdl
               p_b_grid_{p_b_grid},
               p_c_grid_{p_c_grid},
               BatchCount_(BatchCount),
-              a_grid_desc_k0_m_k1_{DeviceBatchedGemmCPermutateXdl::MakeAGridDescriptor_K0_M_K1(
-                  M, K, stride_A)},
-              b_grid_desc_k0_n_k1_{DeviceBatchedGemmCPermutateXdl::MakeBGridDescriptor_K0_N_K1(
-                  K, N, stride_B)},
-              c_grid_desc_m_n_{DeviceBatchedGemmCPermutateXdl::MakeCGridDescriptor_M_N(
+              a_grid_desc_k0_m_k1_{
+                  DeviceBatchedGemmCPermuteXdl::MakeAGridDescriptor_K0_M_K1(M, K, stride_A)},
+              b_grid_desc_k0_n_k1_{
+                  DeviceBatchedGemmCPermuteXdl::MakeBGridDescriptor_K0_N_K1(K, N, stride_B)},
+              c_grid_desc_m_n_{DeviceBatchedGemmCPermuteXdl::MakeCGridDescriptor_M_N(
+                  batched_gemm_c_permute_desc.M_,
+                  batched_gemm_c_permute_desc.N_,
+                  batched_gemm_c_permute_desc.stride_M_,
+                  batched_gemm_c_permute_desc.stride_N_)},
+              e_grid_desc_g0_g1_m_n_{DeviceBatchedGemmCPermuteXdl::MakeEGridDescriptor_G0_G1_M_N(
+                  batched_gemm_c_permute_desc.G0_,
+                  batched_gemm_c_permute_desc.G1_,
                   batched_gemm_c_permute_desc.M_,
                   batched_gemm_c_permute_desc.N_,
+                  batched_gemm_c_permute_desc.stride_G0_,
+                  batched_gemm_c_permute_desc.stride_G1_,
                   batched_gemm_c_permute_desc.stride_M_,
                   batched_gemm_c_permute_desc.stride_N_)},
               c_grid_desc_mblock_mperblock_nblock_nperblock{},
               compute_ptr_offset_of_batch_{
                   type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
                   type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
-                  type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
+                  e_grid_desc_g0_g1_m_n_},
               block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(c_grid_desc_m_n_)},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
@@ -422,6 +464,7 @@ struct DeviceBatchedGemmCPermutateXdl
         AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
         BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
         CGridDesc_M_N c_grid_desc_m_n_;
+        EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
         CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock;
         ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
         Block2CTileMap block_2_ctile_map_;
@@ -433,7 +476,7 @@ struct DeviceBatchedGemmCPermutateXdl
     // Invoker
     struct Invoker : public BaseInvoker
     {
-        using Argument = DeviceBatchedGemmCPermutateXdl::Argument;
+        using Argument = DeviceBatchedGemmCPermuteXdl::Argument;

         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
@@ -456,7 +499,7 @@ struct DeviceBatchedGemmCPermutateXdl
                arg.block_2_ctile_map_))
             {
                 throw std::runtime_error(
-                    "wrong! GridwiseBatchedGemmCPermutate_km_kn_m0m1n0n1_xdlops_v2r3 has invalid "
+                    "wrong! GridwiseBatchedGemmCPermute_km_kn_m0m1n0n1_xdlops_v2r3 has invalid "
                     "setting");
             }
@@ -473,8 +516,8 @@ struct DeviceBatchedGemmCPermutateXdl
                     GridwiseGemm,
                     ADataType, // TODO: distiguish A/B datatype
                     CDataType,
-                    remove_reference_t<DeviceBatchedGemmCPermutateXdl::AGridDesc_K0_M_K1>,
-                    remove_reference_t<DeviceBatchedGemmCPermutateXdl::BGridDesc_K0_N_K1>,
+                    remove_reference_t<DeviceBatchedGemmCPermuteXdl::AGridDesc_K0_M_K1>,
+                    remove_reference_t<DeviceBatchedGemmCPermuteXdl::BGridDesc_K0_N_K1>,
                     typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                     AElementwiseOperation,
                     BElementwiseOperation,
@@ -559,11 +602,11 @@ struct DeviceBatchedGemmCPermutateXdl
         return Argument{p_a,
                         p_b,
                         p_c,
                         M,
                         N,
                         K,
                         stride_A,
                         stride_B,
                         batched_gemm_c_permute_desc,
                         a_element_op,
                         b_element_op,
@@ -574,28 +617,29 @@ struct DeviceBatchedGemmCPermutateXdl
     static auto MakeInvoker() { return Invoker{}; }

     // polymorphic
-    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
-                                                      const void* p_b,
-                                                      void* p_c,
-                                                      index_t M,
-                                                      index_t N,
-                                                      index_t K,
-                                                      index_t stride_A,
-                                                      index_t stride_B,
-                                                      BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
-                                                      AElementwiseOperation a_element_op,
-                                                      BElementwiseOperation b_element_op,
-                                                      CElementwiseOperation c_element_op,
-                                                      index_t BatchCount) override
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        void* p_c,
+                        index_t M,
+                        index_t N,
+                        index_t K,
+                        index_t stride_A,
+                        index_t stride_B,
+                        BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CElementwiseOperation c_element_op,
+                        index_t BatchCount) override
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
                                           static_cast<CDataType*>(p_c),
                                           M,
                                           N,
                                           K,
                                           stride_A,
                                           stride_B,
                                           batched_gemm_c_permute_desc,
                                           a_element_op,
                                           b_element_op,
@@ -615,7 +659,7 @@ struct DeviceBatchedGemmCPermutateXdl
         auto str = std::stringstream();

         // clang-format off
-        str << "DeviceBatchedGemmCPermutateXdl"
+        str << "DeviceBatchedGemmCPermuteXdl"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
......
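A note on the design: in a permuted output layout the distance between consecutive batches is no longer constant, so the old flat BatchStrideC_ cannot describe it. With the example's sizes (G0 = 1024, G1 = 10, M = 256, N = 128) and output layout [G0, M, G1, N], batch g = 9 maps to (b0, b1) = (0, 9) at offset 9 * N = 1152, while batch g = 10 maps to (1, 0) at offset M * G1 * N = 327680; the two steps differ, which is why GetCPtrOffset now asks the [G0, G1, M, N] descriptor for the offset of (b0, b1, 0, 0) instead of multiplying by a single stride.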