"test/old-api/parsertests.cpp" did not exist on "f21456972ca627ef12a708096286fb25daa950f3"
Commit ed3c27cc authored by Chao Liu

update gemm and batch gemm with e permute

parent dfbb659a
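In short, the output tensor formerly named "C" becomes "E" and is addressed through a four-dimensional (G0, G1, M, N) descriptor with explicit strides, so a permuted batched output is expressed purely through the strides. A caller-side sketch, illustrative only and not part of the diff, assuming the device instance, device buffers and element-wise ops are set up as in the updated example below:

// Plain [G0, G1, M, N] layout shown here; a permuted output layout is just a
// different choice of the four strides.
ck::index_t G0 = 16, G1 = 8, M = 256, N = 128, K = 64;
ck::index_t stride_A = K, stride_B = K; // row-major A, column-major B, as in the example
ck::index_t stride_G0 = G1 * M * N, stride_G1 = M * N, stride_M = N, stride_N = 1;

ck::tensor_operation::device::BatchedGemmEPermuteDesc e_desc{
    G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N};

auto argument = gemm.MakeArgument(p_a, p_b, p_e,
                                  M, N, K, stride_A, stride_B,
                                  e_desc,
                                  a_element_op, b_element_op, cde_element_op,
                                  /*BatchCount=*/G0 * G1);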
-add_example_executable(example_batched_gemm_c_permute_xdl_fp16 batched_gemm_c_permute_xdl_fp16.cpp)
+add_example_executable(example_batched_gemm_e_permute_xdl_fp16 batched_gemm_e_permute_xdl_fp16.cpp)
@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute_xdl.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_e_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
@@ -28,33 +28,33 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = ck::half_t;
using BDataType = ck::half_t;
-using CDataType = ck::half_t;
using AccDataType = float;
+using CShuffleDataType = ck::half_t;
+using EDataType = ck::half_t;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
-using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
-using CElementOp = ck::tensor_operation::element_wise::PassThrough;
+using CDEElementOp = ck::tensor_operation::element_wise::PassThrough;
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmCPermuteXdl
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmEPermuteXdl
-//######| ALayout| BLayout| AData| BData| CData| AccData| A| B| C| GEMM| Num| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-//######| | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
-//######| | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
-//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-// < Row, Col, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, MNPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
-< Row, Col, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
+//######| ALayout| BLayout| AData| BData| AccData| CShuffle| EData| A| B| C| GEMM| Num| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+//######| | | Type| Type| Type| Data| Type| Elementwise| Elementwise| Elementwise|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+//######| | | | | | Type| | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+// < Row, Col, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, MNPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
+< Row, Col, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
-ReferenceBatchedGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
+ReferenceBatchedGemm<ADataType, BDataType, EDataType, AElementOp, BElementOp, CDEElementOp>;
int main(int argc, char* argv[])
{
@@ -95,7 +95,7 @@ int main(int argc, char* argv[])
}
// GEMM shape
-ck::tensor_operation::device::BatchedGemmCPermuteDesc batched_gemm_c_permute_desc{
+ck::tensor_operation::device::BatchedGemmEPermuteDesc batched_gemm_e_permute_desc{
G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N};
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
@@ -118,7 +118,7 @@
Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(batch_count, M, K, stride_A, ALayout{}));
Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(batch_count, K, N, stride_B, BLayout{}));
-auto f_host_c_tensor_descriptor = [](std::size_t G0_,
+auto f_host_e_tensor_descriptor = [](std::size_t G0_,
std::size_t G1_,
std::size_t M_,
std::size_t N_,
@@ -131,15 +131,15 @@
std::vector<std::size_t>({stride_G0_, stride_G1_, stride_M_, stride_N_}));
};
-Tensor<CDataType> c_g0_g1_m_n_host_result(
-f_host_c_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
-Tensor<CDataType> c_g0_g1_m_n_device_result(
-f_host_c_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
+Tensor<EDataType> e_g0_g1_m_n_host_result(
+f_host_e_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
+Tensor<EDataType> e_g0_g1_m_n_device_result(
+f_host_e_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
-std::cout << "c_g0_g1_m_n: " << c_g0_g1_m_n_host_result.mDesc << std::endl;
+std::cout << "e_g0_g1_m_n: " << e_g0_g1_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
@@ -156,15 +156,15 @@
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
-DeviceMem c_device_buf(sizeof(CDataType) *
-c_g0_g1_m_n_device_result.mDesc.GetElementSpaceSize());
+DeviceMem e_device_buf(sizeof(EDataType) *
+e_g0_g1_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
-auto c_element_op = CElementOp{};
+auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
@@ -172,16 +172,16 @@
// do GEMM
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
-static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
+static_cast<EDataType*>(e_device_buf.GetDeviceBuffer()),
M,
N,
K,
stride_A,
stride_B,
-batched_gemm_c_permute_desc,
+batched_gemm_e_permute_desc,
a_element_op,
b_element_op,
-c_element_op,
+cde_element_op,
batch_count);
if(!gemm.IsSupportedArgument(argument))
@@ -196,7 +196,7 @@
std::size_t flop = std::size_t(2) * batch_count * M * N * K;
std::size_t num_btype = sizeof(ADataType) * batch_count * M * K +
sizeof(BDataType) * batch_count * K * N +
-sizeof(CDataType) * batch_count * M * N;
+sizeof(EDataType) * batch_count * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -209,16 +209,16 @@
if(do_verification)
{
-c_device_buf.FromDevice(c_g0_g1_m_n_device_result.mData.data());
+e_device_buf.FromDevice(e_g0_g1_m_n_device_result.mData.data());
auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
auto ref_invoker = ref_batched_gemm.MakeInvoker();
-Tensor<CDataType> c_g_m_n_host_result = HostTensorDescriptor(
+Tensor<EDataType> c_g_m_n_host_result = HostTensorDescriptor(
std::vector<std::size_t>({batch_count, M, N}), std::vector<std::size_t>({M * N, N, 1}));
auto ref_argument = ref_batched_gemm.MakeArgument(
-a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);
+a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, cde_element_op);
ref_invoker.Run(ref_argument);
@@ -231,14 +231,15 @@
for(int n = 0; n < N; n++)
{
int g = g0 * G1 + g1;
-c_g0_g1_m_n_host_result(g0, g1, m, n) = c_g_m_n_host_result(g, m, n);
+e_g0_g1_m_n_host_result(g0, g1, m, n) = c_g_m_n_host_result(g, m, n);
}
}
}
}
-pass = ck::utils::check_err(c_g0_g1_m_n_host_result.mData,
-c_g0_g1_m_n_device_result.mData,
+pass = ck::utils::check_err(e_g0_g1_m_n_host_result.mData,
+e_g0_g1_m_n_device_result.mData,
"Error: Incorrect results c");
}
...
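// Illustrative sketch (not part of the diff above): what the (G0, G1, M, N) strides passed to
// f_host_e_tensor_descriptor mean. e_linear_offset is a hypothetical helper; element
// (g0, g1, m, n) of E sits at this linear offset, so permuting the output axes only changes
// the four stride values, not the kernel.
std::size_t e_linear_offset(std::size_t g0, std::size_t g1, std::size_t m, std::size_t n,
                            std::size_t stride_G0, std::size_t stride_G1,
                            std::size_t stride_M, std::size_t stride_N)
{
    return g0 * stride_G0 + g1 * stride_G1 + m * stride_M + n * stride_N;
}
// The verification loop above folds (g0, g1) back into the flat batch index consumed by
// ReferenceBatchedGemm via g = g0 * G1 + g1.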
-add_example_executable(example_gemm_bias_c_permute_xdl_fp16 gemm_bias_c_permute_xdl_fp16.cpp)
+add_example_executable(example_gemm_bias_e_permute_xdl_fp16 gemm_bias_e_permute_xdl_fp16.cpp)
@@ -9,7 +9,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_bias_e_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
@@ -49,7 +49,7 @@ using CDEElementOp = Add;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
-using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmBiasCPermute_Xdl
+using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmBiasEPermute_Xdl
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
...
@@ -38,8 +38,8 @@ add_subdirectory(20_convnd_bwd_weight)
add_subdirectory(21_gemm_layernorm)
add_subdirectory(22_cgemm)
add_subdirectory(23_softmax)
-add_subdirectory(24_batched_gemm_c_permute)
-add_subdirectory(25_gemm_bias_c_permute)
+add_subdirectory(24_batched_gemm_e_permute)
+add_subdirectory(25_gemm_bias_e_permute)
add_subdirectory(26_contraction)
add_subdirectory(27_layernorm)
add_subdirectory(28_group_convnd_fwd_bias_relu)
@@ -8,7 +8,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
-struct BatchedGemmCPermuteDesc
+struct BatchedGemmEPermuteDesc
{
ck::index_t G0_, G1_, M_, N_;
ck::index_t stride_G0_, stride_G1_, stride_M_, stride_N_;
@@ -16,33 +16,27 @@ struct BatchedGemmCPermuteDesc
template <typename AElementwiseOperation,
typename BElementwiseOperation,
-typename CElementwiseOperation>
-struct DeviceBatchedGemmCPermute : public BaseOperator
+typename CDEElementwiseOperation>
+struct DeviceBatchedGemmEPermute : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
-void* p_c,
+void* p_e,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
-BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
+BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
-CElementwiseOperation c_element_op,
+CDEElementwiseOperation cde_element_op,
ck::index_t BatchCount) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
-template <typename AElementwiseOperation,
-typename BElementwiseOperation,
-typename CElementwiseOperation>
-using DeviceBatchedGemmCPermutePtr = std::unique_ptr<
-DeviceBatchedGemmCPermute<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -7,8 +7,9 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -24,9 +25,10 @@ namespace device {
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
+#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
* limitations.
*
-* \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in id of a workgroup and
+* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
@@ -37,40 +39,40 @@ namespace device {
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
-* \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes.
+* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemmCPermute and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template <typename GridwiseGemm,
-typename FloatAB,
-typename FloatC,
-typename AGridDesc_K0_M_K1,
-typename BGridDesc_K0_N_K1,
-typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+typename ABDataType,
+typename EDataType,
+typename AGridDesc_AK0_M_AK1,
+typename BGridDesc_BK0_N_BK1,
+typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename BElementwiseOperation,
-typename CElementwiseOperation,
+typename CDEElementwiseOperation,
typename ComputePtrOffsetOfBatch,
-typename Block2CTileMap,
+typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-kernel_batched_gemm_c_permute_xdl(const FloatAB* __restrict__ p_a_grid,
-const FloatAB* __restrict__ p_b_grid,
-FloatC* __restrict__ p_c_grid,
+kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid,
+const ABDataType* __restrict__ p_b_grid,
+EDataType* __restrict__ p_e_grid,
const index_t batch_count,
-const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
-const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
-const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-c_grid_desc_mblock_mperblock_nblock_nperblock,
+const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+e_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
-const CElementwiseOperation c_element_op,
+const CDEElementwiseOperation cde_element_op,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-const Block2CTileMap block_2_ctile_map)
+const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
const index_t num_blocks_per_batch =
@@ -81,40 +83,37 @@ __global__ void
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
+const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-GridwiseGemm::template Run<HasMainKBlockLoop>(
-p_a_grid + a_batch_offset,
+GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
ck::Tuple<>{},
-p_c_grid + c_batch_offset,
+p_e_grid + e_batch_offset,
p_shared,
a_element_op,
b_element_op,
-c_element_op,
-a_grid_desc_k0_m_k1,
-b_grid_desc_k0_n_k1,
-ck::StaticallyIndexedArray<
-typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-0>{},
-c_grid_desc_mblock_mperblock_nblock_nperblock,
-block_2_ctile_map);
+cde_element_op,
+a_grid_desc_ak0_m_ak1,
+b_grid_desc_bk0_n_bk1,
+ck::Tuple<>{},
+e_grid_desc_mblock_mperblock_nblock_nperblock,
+block_2_etile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
-ignore = p_c_grid;
+ignore = p_e_grid;
ignore = batch_count;
-ignore = a_grid_desc_k0_m_k1;
-ignore = b_grid_desc_k0_n_k1;
-ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
+ignore = a_grid_desc_ak0_m_ak1;
+ignore = b_grid_desc_bk0_n_bk1;
+ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b_element_op;
-ignore = c_element_op;
+ignore = cde_element_op;
ignore = compute_ptr_offset_of_batch;
-ignore = block_2_ctile_map;
+ignore = block_2_etile_map;
#endif
}
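// Illustrative sketch (not part of the diff) of the ComputePtrOffsetOfStridedBatch idea the
// kernel above relies on: given a workgroup's flat batch index g, return how far to advance
// each base pointer. A and B are evenly strided; E's offset follows the permuted
// (G0, G1, M, N) descriptor, which is why the real functor carries that descriptor. The
// struct below is a simplified stand-in with explicit strides, not the one in this header.
struct ComputePtrOffsetOfStridedBatchSketch
{
    ck::index_t batch_stride_a_; // elements between consecutive A batches
    ck::index_t batch_stride_b_; // elements between consecutive B batches
    ck::index_t G1_, stride_G0_, stride_G1_; // E is indexed as (g0, g1) = (g / G1, g % G1)

    __host__ __device__ ck::long_index_t GetAPtrOffset(ck::index_t g) const
    {
        return static_cast<ck::long_index_t>(g) * batch_stride_a_;
    }
    __host__ __device__ ck::long_index_t GetBPtrOffset(ck::index_t g) const
    {
        return static_cast<ck::long_index_t>(g) * batch_stride_b_;
    }
    __host__ __device__ ck::long_index_t GetCPtrOffset(ck::index_t g) const
    {
        const ck::index_t g0 = g / G1_;
        const ck::index_t g1 = g % G1_;
        return static_cast<ck::long_index_t>(g0) * stride_G0_ +
               static_cast<ck::long_index_t>(g1) * stride_G1_;
    }
};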
@@ -122,51 +121,57 @@ template <typename ALayout,
typename BLayout,
typename ADataType,
typename BDataType,
-typename CDataType,
typename AccDataType,
+typename CShuffleDataType,
+typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
-typename CElementwiseOperation,
+typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
-ck::index_t NumPrefetch,
-ck::index_t BlockSize,
-ck::index_t MPerBlock,
-ck::index_t NPerBlock,
-ck::index_t KPerBlock,
-ck::index_t AK1,
-ck::index_t BK1,
-ck::index_t MPerXDL,
-ck::index_t NPerXDL,
-ck::index_t MXdlPerWave,
-ck::index_t NXdlPerWave,
+index_t NumPrefetch,
+index_t BlockSize,
+index_t MPerBlock,
+index_t NPerBlock,
+index_t KPerBlock,
+index_t AK1,
+index_t BK1,
+index_t MPerXDL,
+index_t NPerXDL,
+index_t MXdlPerWave,
+index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
-ck::index_t ABlockTransferSrcVectorDim,
-ck::index_t ABlockTransferSrcScalarPerVector,
-ck::index_t ABlockTransferDstScalarPerVector_K1,
-bool ABlockLdsAddExtraM,
+index_t ABlockTransferSrcVectorDim,
+index_t ABlockTransferSrcScalarPerVector,
+index_t ABlockTransferDstScalarPerVector_K1,
+index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
-ck::index_t BBlockTransferSrcVectorDim,
-ck::index_t BBlockTransferSrcScalarPerVector,
-ck::index_t BBlockTransferDstScalarPerVector_K1,
-bool BBlockLdsAddExtraN,
+index_t BBlockTransferSrcVectorDim,
+index_t BBlockTransferSrcScalarPerVector,
+index_t BBlockTransferDstScalarPerVector_K1,
+index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
-struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementwiseOperation,
+struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<AElementwiseOperation,
BElementwiseOperation,
-CElementwiseOperation>
+CDEElementwiseOperation>
{
+using DeviceOp = DeviceBatchedGemmEPermuteXdl;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
-static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
+static constexpr auto matrix_padder =
+MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
+static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
@@ -181,95 +186,10 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
}
}();
-const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
+return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
-static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
+static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
@@ -284,142 +204,16 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
}
}();
-const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
+return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
static auto
-MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
+MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
{
-const auto c_grid_desc_mraw_nraw = [&]() {
-return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
+const auto e_grid_desc_mraw_nraw =
+make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), make_tuple(stride_M, stride_N));
make_tuple(stride_M, stride_N));
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
-if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
+return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
}
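// Illustrative note (not part of the diff): the matrix_padder calls above replace the
// hand-written descriptor padding that this commit deletes. Conceptually each padded
// extent is the raw extent rounded up to its tile size; pad_up is a hypothetical helper
// showing the arithmetic the removed branches performed with math::integer_divide_ceil
// and make_right_pad_transform.
constexpr ck::index_t pad_up(ck::index_t raw, ck::index_t per_block)
{
    return ((raw + per_block - 1) / per_block) * per_block;
}
// e.g. PadADescriptor_M_K right-pads MRaw to pad_up(MRaw, MPerBlock) and KRaw to
// pad_up(KRaw, KPerBlock); whether M, N and/or K are actually padded is still selected
// by GemmSpec, exactly as in the removed code.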
static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0,
@@ -489,9 +283,9 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
}
}
-using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
-using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
-using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1));
+using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
+using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
+using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1, 1));
using EGridDesc_G0_G1_M_N = decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1));
struct ComputePtrOffsetOfStridedBatch
@@ -529,19 +323,20 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
};
-using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
+using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
-CDataType, // CShuffleDataType,
+CShuffleDataType,
ck::Tuple<>, // DsDataType,
-CDataType, // EDataType,
+EDataType, // EDataType,
AElementwiseOperation,
BElementwiseOperation,
-CElementwiseOperation,
+CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
-AGridDesc_K0_M_K1,
-BGridDesc_K0_N_K1,
-CGridDesc_M_N,
+AGridDesc_M_K,
+BGridDesc_N_K,
+Tuple<>,
+EGridDesc_M_N,
NumPrefetch,
BlockSize,
MPerBlock,
@@ -560,7 +355,7 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
-ABlockLdsAddExtraM,
+ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
@@ -568,118 +363,135 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
-BBlockLdsAddExtraN,
+BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
-using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
-GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
-using Block2CTileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
+using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
+GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
+using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
+GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
+using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
+GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
+using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
-CDataType* p_c_grid,
+EDataType* p_e_grid,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
-BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
+BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
-CElementwiseOperation c_element_op,
+CDEElementwiseOperation cde_element_op,
index_t BatchCount)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
-p_c_grid_{p_c_grid},
+p_e_grid_{p_e_grid},
BatchCount_(BatchCount),
-a_grid_desc_k0_m_k1_{
-DeviceBatchedGemmCPermuteXdl::MakeAGridDescriptor_AK0_M_AK1(M, K, stride_A)},
-b_grid_desc_k0_n_k1_{
-DeviceBatchedGemmCPermuteXdl::MakeBGridDescriptor_BK0_N_BK1(K, N, stride_B)},
-c_grid_desc_m_n_{DeviceBatchedGemmCPermuteXdl::MakeCGridDescriptor_M_N(
-batched_gemm_c_permute_desc.M_,
-batched_gemm_c_permute_desc.N_,
-batched_gemm_c_permute_desc.stride_M_,
-batched_gemm_c_permute_desc.stride_N_)},
-e_grid_desc_g0_g1_m_n_{DeviceBatchedGemmCPermuteXdl::MakeEGridDescriptor_G0_G1_M_N(
-batched_gemm_c_permute_desc.G0_,
-batched_gemm_c_permute_desc.G1_,
-batched_gemm_c_permute_desc.M_,
-batched_gemm_c_permute_desc.N_,
-batched_gemm_c_permute_desc.stride_G0_,
-batched_gemm_c_permute_desc.stride_G1_,
-batched_gemm_c_permute_desc.stride_M_,
-batched_gemm_c_permute_desc.stride_N_)},
-c_grid_desc_mblock_mperblock_nblock_nperblock{},
+a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(M, K, stride_A)},
+b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(K, N, stride_B)},
+e_grid_desc_m_n_{
+DeviceOp::MakeEGridDescriptor_M_N(batched_gemm_e_permute_desc.M_,
+batched_gemm_e_permute_desc.N_,
+batched_gemm_e_permute_desc.stride_M_,
+batched_gemm_e_permute_desc.stride_N_)},
+a_grid_desc_ak0_m_ak1_{
+GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
+b_grid_desc_bk0_n_bk1_{
+GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
+e_grid_desc_mblock_mperblock_nblock_nperblock{},
+e_grid_desc_g0_g1_m_n_{
+DeviceOp::MakeEGridDescriptor_G0_G1_M_N(batched_gemm_e_permute_desc.G0_,
+batched_gemm_e_permute_desc.G1_,
+batched_gemm_e_permute_desc.M_,
+batched_gemm_e_permute_desc.N_,
+batched_gemm_e_permute_desc.stride_G0_,
+batched_gemm_e_permute_desc.stride_G1_,
+batched_gemm_e_permute_desc.stride_M_,
+batched_gemm_e_permute_desc.stride_N_)},
compute_ptr_offset_of_batch_{
-type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
-type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
+type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
+type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
e_grid_desc_g0_g1_m_n_},
-block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(c_grid_desc_m_n_)},
+block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
-c_element_op_{c_element_op}
+cde_element_op_{cde_element_op}
{
-if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
-b_grid_desc_k0_n_k1_,
-c_grid_desc_m_n_,
-block_2_ctile_map_))
+if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
+b_grid_desc_n_k_,
+ck::Tuple<>{},
+e_grid_desc_m_n_,
+block_2_etile_map_))
{
-c_grid_desc_mblock_mperblock_nblock_nperblock =
+e_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-c_grid_desc_m_n_);
+e_grid_desc_m_n_);
}
}
void Print() const
{
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
std::cout << "C[M, N]: " << e_grid_desc_m_n_ << std::endl;
}
// private:
+// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
-CDataType* p_c_grid_;
+EDataType* p_e_grid_;
+// batch count
index_t BatchCount_;
-AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
-BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
-CGridDesc_M_N c_grid_desc_m_n_;
+// tensor descriptors for problem definiton
+AGridDesc_M_K a_grid_desc_m_k_;
+BGridDesc_N_K b_grid_desc_n_k_;
+EGridDesc_M_N e_grid_desc_m_n_;
+// tensor descriptors for block/thread-wise copy
+AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
+BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
+EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
-CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock;
+// for calculating Batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
-Block2CTileMap block_2_ctile_map_;
+// block-to-e-tile map
+Block2ETileMap block_2_etile_map_;
+// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
-CElementwiseOperation c_element_op_;
+CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
-using Argument = DeviceBatchedGemmCPermuteXdl::Argument;
+using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
-{
-std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
-<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
-<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
-std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
-<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
-<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
-std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
-<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
-}
-if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
-arg.b_grid_desc_k0_n_k1_,
-arg.c_grid_desc_m_n_,
-arg.block_2_ctile_map_))
+if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
+arg.b_grid_desc_n_k_,
+ck::Tuple<>{},
+arg.e_grid_desc_m_n_,
+arg.block_2_etile_map_))
{
throw std::runtime_error(
"wrong! GridwiseBatchedGemmCPermute_km_kn_m0m1n0n1_xdlops_v2r3 has invalid "
@@ -687,26 +499,24 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
}
const index_t grid_size =
-arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
+arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.BatchCount_;
const auto K =
-arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
+arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
-float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
-const auto kernel = kernel_batched_gemm_c_permute_xdl<
+const auto kernel = kernel_batched_gemm_e_permute_xdl<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
-CDataType,
-remove_reference_t<DeviceBatchedGemmCPermuteXdl::AGridDesc_K0_M_K1>,
-remove_reference_t<DeviceBatchedGemmCPermuteXdl::BGridDesc_K0_N_K1>,
+EDataType,
+remove_reference_t<DeviceOp::AGridDesc_AK0_M_AK1>,
+remove_reference_t<DeviceOp::BGridDesc_BK0_N_BK1>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation,
BElementwiseOperation,
-CElementwiseOperation,
+CDEElementwiseOperation,
ComputePtrOffsetOfStridedBatch,
-remove_reference_t<Block2CTileMap>,
+remove_reference_t<Block2ETileMap>,
has_main_k_block_loop_>;
return launch_and_time_kernel(stream_config,
@@ -716,28 +526,26 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
0,
arg.p_a_grid_,
arg.p_b_grid_,
-arg.p_c_grid_,
+arg.p_e_grid_,
arg.BatchCount_,
-arg.a_grid_desc_k0_m_k1_,
-arg.b_grid_desc_k0_n_k1_,
-arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
+arg.a_grid_desc_ak0_m_ak1_,
+arg.b_grid_desc_bk0_n_bk1_,
+arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
-arg.c_element_op_,
+arg.cde_element_op_,
arg.compute_ptr_offset_of_batch_,
-arg.block_2_ctile_map_);
+arg.block_2_etile_map_);
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
-ave_time = launch_kernel(integral_constant<bool, true>{});
+return launch_kernel(integral_constant<bool, true>{});
}
else
{
-ave_time = launch_kernel(integral_constant<bool, false>{});
+return launch_kernel(integral_constant<bool, false>{});
}
-return ave_time;
}
// polymorphic
@@ -756,10 +564,11 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
static bool IsSupportedArgument(const Argument& arg)
{
-return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
-arg.b_grid_desc_k0_n_k1_,
-arg.c_grid_desc_m_n_,
-arg.block_2_ctile_map_);
+return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
+arg.b_grid_desc_n_k_,
+ck::Tuple<>{},
+arg.e_grid_desc_m_n_,
+arg.block_2_etile_map_);
}
// polymorphic
...@@ -770,30 +579,30 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw ...@@ -770,30 +579,30 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
static auto MakeArgument(const ADataType* p_a, static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b, const BDataType* p_b,
CDataType* p_c, EDataType* p_e,
index_t M, index_t M,
index_t N, index_t N,
index_t K, index_t K,
index_t stride_A, index_t stride_A,
index_t stride_B, index_t stride_B,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc, BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CDEElementwiseOperation cde_element_op,
index_t BatchCount) index_t BatchCount)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
p_c, p_e,
M, M,
N, N,
K, K,
stride_A, stride_A,
stride_B, stride_B,
batched_gemm_c_permute_desc, batched_gemm_e_permute_desc,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, cde_element_op,
BatchCount}; BatchCount};
} }
...@@ -803,30 +612,30 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw ...@@ -803,30 +612,30 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
void* p_c, void* p_e,
index_t M, index_t M,
index_t N, index_t N,
index_t K, index_t K,
index_t stride_A, index_t stride_A,
index_t stride_B, index_t stride_B,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc, BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CDEElementwiseOperation cde_element_op,
index_t BatchCount) override index_t BatchCount) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c), static_cast<EDataType*>(p_e),
M, M,
N, N,
K, K,
stride_A, stride_A,
stride_B, stride_B,
batched_gemm_c_permute_desc, batched_gemm_e_permute_desc,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, cde_element_op,
BatchCount); BatchCount);
} }
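MakeArgumentPointer mirrors MakeArgument but erases the element types behind void*, so the op can be driven through the polymorphic BaseArgument/BaseInvoker interface. A rough usage sketch, assuming the CK headers used in this commit are on the include path and that the polymorphic IsSupportedArgument/Run overloads take a BaseArgument pointer; buffer pointers and problem sizes are placeholders:

#include <memory>

// Sketch only: DeviceOp is any DeviceBatchedGemmEPermuteXdl<...> instance type,
// EPermuteDesc its BatchedGemmEPermuteDesc, and the element ops are passed in
// so this helper does not hard-code PassThrough.
template <typename DeviceOp, typename EPermuteDesc, typename AOp, typename BOp, typename CDEOp>
float run_batched_gemm_e_permute(DeviceOp& op,
                                 const void* p_a_dev,
                                 const void* p_b_dev,
                                 void* p_e_dev,
                                 ck::index_t M, ck::index_t N, ck::index_t K,
                                 ck::index_t StrideA, ck::index_t StrideB,
                                 const EPermuteDesc& e_permute_desc,
                                 AOp a_op, BOp b_op, CDEOp cde_op,
                                 ck::index_t BatchCount)
{
    auto argument_ptr = op.MakeArgumentPointer(p_a_dev, p_b_dev, p_e_dev,
                                               M, N, K, StrideA, StrideB,
                                               e_permute_desc,
                                               a_op, b_op, cde_op,
                                               BatchCount);
    auto invoker_ptr  = op.MakeInvokerPointer();

    // Assumed polymorphic overload taking const BaseArgument*.
    if(!op.IsSupportedArgument(argument_ptr.get()))
    {
        return -1.0f; // caller decides how to handle unsupported problem shapes
    }

    // StreamConfig{stream, time_kernel}: time the launch on the default stream.
    return invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
}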
...@@ -842,7 +651,7 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw ...@@ -842,7 +651,7 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceBatchedGemmCPermuteXdl" str << "DeviceBatchedGemmEPermuteXdl"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -46,12 +46,6 @@ struct DeviceGemmBiasCPermute : public BaseOperator ...@@ -46,12 +46,6 @@ struct DeviceGemmBiasCPermute : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceGemmBiasCPermutePtr = std::unique_ptr<
DeviceGemmBiasCPermute<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
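With the DeviceGemmBiasCPermutePtr alias removed, callers spell the owning pointer type out against the base interface themselves. A small sketch of the equivalent local declaration, assuming PassThrough element ops; the concrete object it would hold is whichever DeviceGemmBiasEPermute_Xdl<...> specialization gets selected:

#include <memory>

// Local replacement for the removed alias; purely illustrative.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using DeviceGemmBiasEPermuteBasePtr =
    std::unique_ptr<ck::tensor_operation::device::
                        DeviceGemmBiasCPermute<PassThrough, PassThrough, PassThrough>>;

Keeping ownership at the base-class level is what lets MakeArgumentPointer and MakeInvokerPointer stay type-erased on the caller's side.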
...@@ -10,8 +10,9 @@ ...@@ -10,8 +10,9 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -35,7 +36,7 @@ __global__ void ...@@ -35,7 +36,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_bias_c_permute(const FloatAB* __restrict__ p_a_grid, kernel_gemm_bias_e_permute(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatDsPointer p_ds_grid, FloatDsPointer p_ds_grid,
FloatE* __restrict__ p_e_grid, FloatE* __restrict__ p_e_grid,
...@@ -99,7 +100,7 @@ template <typename ALayout, ...@@ -99,7 +100,7 @@ template <typename ALayout,
typename CDELayout, typename CDELayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename GemmAccDataType, typename AccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename DDataType, typename DDataType,
typename EDataType, typename EDataType,
...@@ -124,33 +125,36 @@ template <typename ALayout, ...@@ -124,33 +125,36 @@ template <typename ALayout,
index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1, index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM, index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1, typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder, typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder, typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1, index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN, index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock, index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOperation, struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation> CDEElementwiseOperation>
{ {
using DeviceOp = DeviceGemmBiasCPermute_Xdl; using DeviceOp = DeviceGemmBiasEPermute_Xdl;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr index_t NumDTensor = I1; static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) static constexpr index_t NumDTensor = 1;
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{ {
const auto a_grid_desc_mraw_kraw = [&]() { const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>) if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
...@@ -165,95 +169,10 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -165,95 +169,10 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
} }
}(); }();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
} }
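The four GemmSpec branches that used to live here are now folded into matrix_padder.PadADescriptor_M_K; which dimensions get padded still depends on GemmSpec, and the padded extents follow the same round-up-to-tile arithmetic the removed code computed inline. A standalone illustration of that arithmetic (not the CK implementation):

struct PaddedMK
{
    int M; // MRaw rounded up to a multiple of MPerBlock (when GemmSpec pads M)
    int K; // KRaw rounded up to a multiple of KPerBlock (when GemmSpec pads K)
};

PaddedMK pad_m_k(int MRaw, int KRaw, int MPerBlock, int KPerBlock)
{
    auto round_up = [](int x, int tile) { return ((x + tile - 1) / tile) * tile; };

    // Example: MRaw = 1000, MPerBlock = 256 -> M = 1024, i.e. 24 rows of right padding.
    return {round_up(MRaw, MPerBlock), round_up(KRaw, KPerBlock)};
}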
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{ {
const auto b_grid_desc_nraw_kraw = [&]() { const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
...@@ -268,92 +187,7 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -268,92 +187,7 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
} }
}(); }();
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
} }
static auto MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1 d_e_grid_desc) static auto MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1 d_e_grid_desc)
...@@ -370,73 +204,32 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -370,73 +204,32 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
index_t stride_N0 = d_e_grid_desc.stride_N0_; index_t stride_N0 = d_e_grid_desc.stride_N0_;
index_t stride_N1 = d_e_grid_desc.stride_N1_; index_t stride_N1 = d_e_grid_desc.stride_N1_;
const auto MRaw = M0 * M1 * M2; const auto e_grid_desc_mraw_nraw = [&]() {
const auto NRaw = N0 * N1; const auto e_grid_desc_m0_m1_m2_n0_n1 = make_naive_tensor_descriptor(
const auto c_grid_desc_mraw_nraw = [&]() {
const auto c_grid_desc_m0_m1_m2_n0_n1 = make_naive_tensor_descriptor(
make_tuple(M0, M1, M2, N0, N1), make_tuple(M0, M1, M2, N0, N1),
make_tuple(stride_M0, stride_M1, stride_M2, stride_N0, stride_N1)); make_tuple(stride_M0, stride_M1, stride_M2, stride_N0, stride_N1));
return transform_tensor_descriptor( return transform_tensor_descriptor(
c_grid_desc_m0_m1_m2_n0_n1, e_grid_desc_m0_m1_m2_n0_n1,
make_tuple(make_merge_transform(make_tuple(M0, M1, M2)), make_tuple(make_merge_transform(make_tuple(M0, M1, M2)),
make_merge_transform(make_tuple(N0, N1))), make_merge_transform(make_tuple(N0, N1))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}), make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
}(); }();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
} }
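The rewritten MakeEGridDescriptor_M_N keeps the user-supplied permuted strides and merges (M0, M1, M2) into M and (N0, N1) into N before handing the result to PadCDescriptor_M_N. Concretely, a flat (m, n) coordinate is split back into the five sub-indices and dotted with the strides; a standalone sketch of that address computation (hypothetical helper, not CK code):

#include <cstdint>

// Assumes M = M0 * M1 * M2 and N = N0 * N1 before padding.
std::int64_t e_element_offset(int m, int n,
                              int M1, int M2, int N1,
                              std::int64_t stride_M0, std::int64_t stride_M1,
                              std::int64_t stride_M2,
                              std::int64_t stride_N0, std::int64_t stride_N1)
{
    const int m0 = m / (M1 * M2);
    const int m1 = (m / M2) % M1;
    const int m2 = m % M2;
    const int n0 = n / N1;
    const int n1 = n % N1;

    return m0 * stride_M0 + m1 * stride_M1 + m2 * stride_M2
         + n0 * stride_N0 + n1 * stride_N1;
}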
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1{})); using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1{}));
using DsGridDesc_M_N = Tuple<EGridDesc_M_N>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle< using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
GemmAccDataType, AccDataType,
CShuffleDataType, CShuffleDataType,
ck::Tuple<DDataType>, ck::Tuple<DDataType>,
EDataType, EDataType,
...@@ -444,8 +237,9 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -444,8 +237,9 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1, AGridDesc_M_K,
BGridDesc_BK0_N_BK1, BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N, EGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
...@@ -480,6 +274,13 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -480,6 +274,13 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
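The descriptors now come in two levels: the M_K / N_K / M_N descriptors above describe the problem and feed CheckValidity, while the AK0_M_AK1 / BK0_N_BK1 views and the Block2ETileMap are derived by the gridwise GEMM for the block-wise copies and for assigning workgroups to E tiles. A rough illustration of what a block-to-E-tile map produces (CK's DefaultBlock2ETileMap may traverse tiles in a different order):

struct ETileCoord
{
    int m_block;
    int n_block;
};

// Simple row-major assignment of a linear workgroup id to an E tile.
ETileCoord block_to_e_tile(int block_id, int N, int NPerBlock)
{
    const int n_tiles = (N + NPerBlock - 1) / NPerBlock;
    return {block_id / n_tiles, block_id % n_tiles};
}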
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -499,12 +300,17 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -499,12 +300,17 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
CDEElementwiseOperation cde_element_op) CDEElementwiseOperation cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)}, : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)}, p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{}, // FIXME p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)}, p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_grid_desc)}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_grid_desc)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
...@@ -522,8 +328,16 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -522,8 +328,16 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
} }
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, // populate pointer, desc for Ds
b_grid_desc_bk0_n_bk1_, // D pointer
p_ds_grid_(I0) = static_cast<const DDataType*>(p_d_grid);
// D desc
ds_grid_desc_m_n_(I0) = DeviceOp::MakeEGridDescriptor_M_N(d_grid_desc);
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_, e_grid_desc_m_n_,
block_2_etile_map_)) block_2_etile_map_))
{ {
...@@ -531,32 +345,37 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -531,32 +345,37 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_); e_grid_desc_m_n_);
p_ds_grid_(I0) = static_cast<const DDataType*>(p_d_grid);
const auto d_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N(d_grid_desc);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(I0) = ds_grid_desc_mblock_mperblock_nblock_nperblock_(I0) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
d_grid_desc_m_n); ds_grid_desc_m_n_[I0]);
} }
} }
// private: // private:
// pointers
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_; typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_; EDataType* p_e_grid_;
// tensor descriptors for problem definition
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
StaticallyIndexedArray< typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, ds_grid_desc_mblock_mperblock_nblock_nperblock_;
NumDTensor>
ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
// type from E
EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_; e_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
...@@ -569,8 +388,9 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -569,8 +388,9 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.block_2_etile_map_)) arg.block_2_etile_map_))
{ {
...@@ -586,7 +406,7 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -586,7 +406,7 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
auto launch_kernel = [&](auto has_main_k_block_loop) { auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_gemm_bias_c_permute< const auto kernel = kernel_gemm_bias_e_permute<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
typename GridwiseGemm::DsGridPointer, typename GridwiseGemm::DsGridPointer,
...@@ -596,9 +416,7 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -596,9 +416,7 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
CDEElementwiseOperation, CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
ck::StaticallyIndexedArray< typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2ETileMap, typename GridwiseGemm::DefaultBlock2ETileMap,
has_main_loop>; has_main_loop>;
...@@ -622,18 +440,14 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -622,18 +440,14 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
arg.block_2_etile_map_); arg.block_2_etile_map_);
}; };
float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
ave_time = launch_kernel(integral_constant<bool, true>{}); return launch_kernel(integral_constant<bool, true>{});
} }
else else
{ {
ave_time = launch_kernel(integral_constant<bool, false>{}); return launch_kernel(integral_constant<bool, false>{});
} }
return ave_time;
} }
// polymorphic // polymorphic
...@@ -651,8 +465,9 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -651,8 +465,9 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
return false; return false;
} }
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.block_2_etile_map_); arg.block_2_etile_map_);
} }
...@@ -741,7 +556,7 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -741,7 +556,7 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceGemmBiasCPermute_Xdl" str << "DeviceGemmBiasEPermute_Xdl"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -205,12 +205,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -205,12 +205,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{ {
const auto e_grid_desc_mraw_nraw = [&]() { const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1)); make_tuple(StrideE, I1));
} }
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value) else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
{ {
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE)); make_tuple(I1, StrideE));
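The rename from ELayout to ELay does not change what the two branches encode: a naive (M, N) descriptor whose leading dimension is StrideE. A quick sketch of the element addressing each branch implies (plain helpers for illustration):

// Row-major: n is contiguous, rows are StrideE apart.
long long e_index_row_major(int m, int n, int StrideE)
{
    return static_cast<long long>(m) * StrideE + n;
}

// Column-major: m is contiguous, columns are StrideE apart.
long long e_index_column_major(int m, int n, int StrideE)
{
    return m + static_cast<long long>(n) * StrideE;
}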
...@@ -329,7 +329,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -329,7 +329,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op} cde_element_op_{cde_element_op}
{ {
// populate pointer, batch stride, desc for Ds // populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>; using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
......
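The constructor walks the Ds tuple with static_for so that each D tensor can carry its own layout and data type; the loop index has to be a compile-time constant because it selects tuple elements and types. The same idea in plain C++17, without the CK utilities (hypothetical helper names):

#include <cstddef>
#include <tuple>
#include <type_traits>
#include <utility>

// Applies f(element, index_constant) to every element of a heterogeneous tuple.
template <typename TupleT, typename F, std::size_t... Is>
void for_each_d_impl(TupleT& ds, F&& f, std::index_sequence<Is...>)
{
    (f(std::get<Is>(ds), std::integral_constant<std::size_t, Is>{}), ...);
}

template <typename TupleT, typename F>
void for_each_d(TupleT& ds, F&& f)
{
    for_each_d_impl(ds, std::forward<F>(f),
                    std::make_index_sequence<std::tuple_size_v<TupleT>>{});
}

// Usage sketch: for_each_d(ds_descriptors, [&](auto& d_desc, auto i) { /* build entry i */ });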