Commit ed3c27cc authored by Chao Liu's avatar Chao Liu
Browse files

update gemm and batch gemm with e permute

parent dfbb659a
add_example_executable(example_batched_gemm_c_permute_xdl_fp16 batched_gemm_c_permute_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_e_permute_xdl_fp16 batched_gemm_e_permute_xdl_fp16.cpp)
......@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_e_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
......@@ -26,35 +26,35 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using EDataType = ck::half_t;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::PassThrough;
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmCPermuteXdl
//######| ALayout| BLayout| AData| BData| CData| AccData| A| B| C| GEMM| Num| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// < Row, Col, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, MNPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
< Row, Col, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmEPermuteXdl
//######| ALayout| BLayout| AData| BData| AccData| CShuffle| EData| A| B| C| GEMM| Num| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | Type| Type| Type| Data| Type| Elementwise| Elementwise| Elementwise|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | Type| | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// < Row, Col, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, MNPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
< Row, Col, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
ReferenceBatchedGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
ReferenceBatchedGemm<ADataType, BDataType, EDataType, AElementOp, BElementOp, CDEElementOp>;
int main(int argc, char* argv[])
{
......@@ -95,7 +95,7 @@ int main(int argc, char* argv[])
}
// GEMM shape
ck::tensor_operation::device::BatchedGemmCPermuteDesc batched_gemm_c_permute_desc{
ck::tensor_operation::device::BatchedGemmEPermuteDesc batched_gemm_e_permute_desc{
G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N};
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
......@@ -118,7 +118,7 @@ int main(int argc, char* argv[])
Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(batch_count, M, K, stride_A, ALayout{}));
Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(batch_count, K, N, stride_B, BLayout{}));
auto f_host_c_tensor_descriptor = [](std::size_t G0_,
auto f_host_e_tensor_descriptor = [](std::size_t G0_,
std::size_t G1_,
std::size_t M_,
std::size_t N_,
......@@ -131,15 +131,15 @@ int main(int argc, char* argv[])
std::vector<std::size_t>({stride_G0_, stride_G1_, stride_M_, stride_N_}));
};
Tensor<CDataType> c_g0_g1_m_n_host_result(
f_host_c_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
Tensor<EDataType> e_g0_g1_m_n_host_result(
f_host_e_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
Tensor<CDataType> c_g0_g1_m_n_device_result(
f_host_c_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
Tensor<EDataType> e_g0_g1_m_n_device_result(
f_host_e_tensor_descriptor(G0, G1, M, N, stride_G0, stride_G1, stride_M, stride_N));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
std::cout << "c_g0_g1_m_n: " << c_g0_g1_m_n_host_result.mDesc << std::endl;
std::cout << "e_g0_g1_m_n: " << e_g0_g1_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
......@@ -156,15 +156,15 @@ int main(int argc, char* argv[])
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) *
c_g0_g1_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) *
e_g0_g1_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
......@@ -172,16 +172,16 @@ int main(int argc, char* argv[])
// do GEMM
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<EDataType*>(e_device_buf.GetDeviceBuffer()),
M,
N,
K,
stride_A,
stride_B,
batched_gemm_c_permute_desc,
batched_gemm_e_permute_desc,
a_element_op,
b_element_op,
c_element_op,
cde_element_op,
batch_count);
if(!gemm.IsSupportedArgument(argument))
......@@ -196,7 +196,7 @@ int main(int argc, char* argv[])
std::size_t flop = std::size_t(2) * batch_count * M * N * K;
std::size_t num_btype = sizeof(ADataType) * batch_count * M * K +
sizeof(BDataType) * batch_count * K * N +
sizeof(CDataType) * batch_count * M * N;
sizeof(EDataType) * batch_count * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
......@@ -209,16 +209,16 @@ int main(int argc, char* argv[])
if(do_verification)
{
c_device_buf.FromDevice(c_g0_g1_m_n_device_result.mData.data());
e_device_buf.FromDevice(e_g0_g1_m_n_device_result.mData.data());
auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
auto ref_invoker = ref_batched_gemm.MakeInvoker();
Tensor<CDataType> c_g_m_n_host_result = HostTensorDescriptor(
Tensor<EDataType> c_g_m_n_host_result = HostTensorDescriptor(
std::vector<std::size_t>({batch_count, M, N}), std::vector<std::size_t>({M * N, N, 1}));
auto ref_argument = ref_batched_gemm.MakeArgument(
a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);
a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, cde_element_op);
ref_invoker.Run(ref_argument);
......@@ -230,15 +230,16 @@ int main(int argc, char* argv[])
{
for(int n = 0; n < N; n++)
{
int g = g0 * G1 + g1;
c_g0_g1_m_n_host_result(g0, g1, m, n) = c_g_m_n_host_result(g, m, n);
int g = g0 * G1 + g1;
e_g0_g1_m_n_host_result(g0, g1, m, n) = c_g_m_n_host_result(g, m, n);
}
}
}
}
pass = ck::utils::check_err(c_g0_g1_m_n_host_result.mData,
c_g0_g1_m_n_device_result.mData,
pass = ck::utils::check_err(e_g0_g1_m_n_host_result.mData,
e_g0_g1_m_n_device_result.mData,
"Error: Incorrect results c");
}
......
add_example_executable(example_gemm_bias_c_permute_xdl_fp16 gemm_bias_c_permute_xdl_fp16.cpp)
add_example_executable(example_gemm_bias_e_permute_xdl_fp16 gemm_bias_e_permute_xdl_fp16.cpp)
......@@ -9,7 +9,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_c_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_e_permute_xdl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
......@@ -49,7 +49,7 @@ using CDEElementOp = Add;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmBiasCPermute_Xdl
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmBiasEPermute_Xdl
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
......
......@@ -38,8 +38,8 @@ add_subdirectory(20_convnd_bwd_weight)
add_subdirectory(21_gemm_layernorm)
add_subdirectory(22_cgemm)
add_subdirectory(23_softmax)
add_subdirectory(24_batched_gemm_c_permute)
add_subdirectory(25_gemm_bias_c_permute)
add_subdirectory(24_batched_gemm_e_permute)
add_subdirectory(25_gemm_bias_e_permute)
add_subdirectory(26_contraction)
add_subdirectory(27_layernorm)
add_subdirectory(28_group_convnd_fwd_bias_relu)
......@@ -8,7 +8,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
struct BatchedGemmCPermuteDesc
struct BatchedGemmEPermuteDesc
{
ck::index_t G0_, G1_, M_, N_;
ck::index_t stride_G0_, stride_G1_, stride_M_, stride_N_;
......@@ -16,33 +16,27 @@ struct BatchedGemmCPermuteDesc
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceBatchedGemmCPermute : public BaseOperator
typename CDEElementwiseOperation>
struct DeviceBatchedGemmEPermute : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
CDEElementwiseOperation cde_element_op,
ck::index_t BatchCount) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceBatchedGemmCPermutePtr = std::unique_ptr<
DeviceBatchedGemmCPermute<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -7,8 +7,9 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -24,9 +25,10 @@ namespace device {
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
* limitations.
*
* \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in id of a workgroup and
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
......@@ -37,40 +39,40 @@ namespace device {
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes.
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemmCPermute and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename ABDataType,
typename EDataType,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CDEElementwiseOperation,
typename ComputePtrOffsetOfBatch,
typename Block2CTileMap,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_c_permute_xdl(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
EDataType* __restrict__ p_e_grid,
const index_t batch_count,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const CDEElementwiseOperation cde_element_op,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map)
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
const index_t num_blocks_per_batch =
......@@ -81,40 +83,37 @@ __global__ void
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
ck::Tuple<>{},
p_c_grid + c_batch_offset,
p_shared,
a_element_op,
b_element_op,
c_element_op,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
ck::StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
0>{},
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
ck::Tuple<>{},
p_e_grid + e_batch_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ck::Tuple<>{},
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = p_e_grid;
ignore = batch_count;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = cde_element_op;
ignore = compute_ptr_offset_of_batch;
ignore = block_2_ctile_map;
ignore = block_2_etile_map;
#endif
}
......@@ -122,51 +121,57 @@ template <typename ALayout,
typename BLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename CShuffleDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumPrefetch,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
index_t NumPrefetch,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementwiseOperation,
struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
CDEElementwiseOperation>
{
using DeviceOp = DeviceBatchedGemmEPermuteXdl;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
......@@ -181,95 +186,10 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
}
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
......@@ -284,142 +204,16 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
}
}();
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
static auto
MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
{
const auto c_grid_desc_mraw_nraw = [&]() {
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(stride_M, stride_N));
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto e_grid_desc_mraw_nraw =
make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), make_tuple(stride_M, stride_N));
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0,
......@@ -489,9 +283,9 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
}
}
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1));
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1, 1));
using EGridDesc_G0_G1_M_N = decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1));
struct ComputePtrOffsetOfStridedBatch
......@@ -529,19 +323,20 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
};
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CDataType, // CShuffleDataType,
CShuffleDataType,
ck::Tuple<>, // DsDataType,
CDataType, // EDataType,
EDataType, // EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AGridDesc_M_K,
BGridDesc_N_K,
Tuple<>,
EGridDesc_M_N,
NumPrefetch,
BlockSize,
MPerBlock,
......@@ -560,7 +355,7 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
......@@ -568,118 +363,135 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
using Block2CTileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
EDataType* p_e_grid,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
CDEElementwiseOperation cde_element_op,
index_t BatchCount)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
p_e_grid_{p_e_grid},
BatchCount_(BatchCount),
a_grid_desc_k0_m_k1_{
DeviceBatchedGemmCPermuteXdl::MakeAGridDescriptor_AK0_M_AK1(M, K, stride_A)},
b_grid_desc_k0_n_k1_{
DeviceBatchedGemmCPermuteXdl::MakeBGridDescriptor_BK0_N_BK1(K, N, stride_B)},
c_grid_desc_m_n_{DeviceBatchedGemmCPermuteXdl::MakeCGridDescriptor_M_N(
batched_gemm_c_permute_desc.M_,
batched_gemm_c_permute_desc.N_,
batched_gemm_c_permute_desc.stride_M_,
batched_gemm_c_permute_desc.stride_N_)},
e_grid_desc_g0_g1_m_n_{DeviceBatchedGemmCPermuteXdl::MakeEGridDescriptor_G0_G1_M_N(
batched_gemm_c_permute_desc.G0_,
batched_gemm_c_permute_desc.G1_,
batched_gemm_c_permute_desc.M_,
batched_gemm_c_permute_desc.N_,
batched_gemm_c_permute_desc.stride_G0_,
batched_gemm_c_permute_desc.stride_G1_,
batched_gemm_c_permute_desc.stride_M_,
batched_gemm_c_permute_desc.stride_N_)},
c_grid_desc_mblock_mperblock_nblock_nperblock{},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(M, K, stride_A)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(K, N, stride_B)},
e_grid_desc_m_n_{
DeviceOp::MakeEGridDescriptor_M_N(batched_gemm_e_permute_desc.M_,
batched_gemm_e_permute_desc.N_,
batched_gemm_e_permute_desc.stride_M_,
batched_gemm_e_permute_desc.stride_N_)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
e_grid_desc_mblock_mperblock_nblock_nperblock{},
e_grid_desc_g0_g1_m_n_{
DeviceOp::MakeEGridDescriptor_G0_G1_M_N(batched_gemm_e_permute_desc.G0_,
batched_gemm_e_permute_desc.G1_,
batched_gemm_e_permute_desc.M_,
batched_gemm_e_permute_desc.N_,
batched_gemm_e_permute_desc.stride_G0_,
batched_gemm_e_permute_desc.stride_G1_,
batched_gemm_e_permute_desc.stride_M_,
batched_gemm_e_permute_desc.stride_N_)},
compute_ptr_offset_of_batch_{
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
e_grid_desc_g0_g1_m_n_},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(c_grid_desc_m_n_)},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
cde_element_op_{cde_element_op}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ck::Tuple<>{},
e_grid_desc_m_n_,
block_2_etile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock =
e_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
e_grid_desc_m_n_);
}
}
void Print() const
{
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
std::cout << "C[M, N]: " << e_grid_desc_m_n_ << std::endl;
}
// private:
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
EDataType* p_e_grid_;
// batch count
index_t BatchCount_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
// tensor descriptors for problem definiton
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock;
// for calculating Batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
Block2CTileMap block_2_ctile_map_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
CDEElementwiseOperation cde_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceBatchedGemmCPermuteXdl::Argument;
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
ck::Tuple<>{},
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error(
"wrong! GridwiseBatchedGemmCPermute_km_kn_m0m1n0n1_xdlops_v2r3 has invalid "
......@@ -687,26 +499,24 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
}
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.BatchCount_;
const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
float ave_time = 0;
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_gemm_c_permute_xdl<
const auto kernel = kernel_batched_gemm_e_permute_xdl<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceBatchedGemmCPermuteXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceBatchedGemmCPermuteXdl::BGridDesc_K0_N_K1>,
EDataType,
remove_reference_t<DeviceOp::AGridDesc_AK0_M_AK1>,
remove_reference_t<DeviceOp::BGridDesc_BK0_N_BK1>,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
CDEElementwiseOperation,
ComputePtrOffsetOfStridedBatch,
remove_reference_t<Block2CTileMap>,
remove_reference_t<Block2ETileMap>,
has_main_k_block_loop_>;
return launch_and_time_kernel(stream_config,
......@@ -716,28 +526,26 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.p_e_grid_,
arg.BatchCount_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.cde_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.block_2_ctile_map_);
arg.block_2_etile_map_);
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
ave_time = launch_kernel(integral_constant<bool, true>{});
return launch_kernel(integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{});
return launch_kernel(integral_constant<bool, false>{});
}
return ave_time;
}
// polymorphic
......@@ -756,10 +564,11 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
ck::Tuple<>{},
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
// polymorphic
......@@ -770,30 +579,30 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
EDataType* p_e,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
CDEElementwiseOperation cde_element_op,
index_t BatchCount)
{
return Argument{p_a,
p_b,
p_c,
p_e,
M,
N,
K,
stride_A,
stride_B,
batched_gemm_c_permute_desc,
batched_gemm_e_permute_desc,
a_element_op,
b_element_op,
c_element_op,
cde_element_op,
BatchCount};
}
......@@ -803,30 +612,30 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
BatchedGemmCPermuteDesc batched_gemm_c_permute_desc,
BatchedGemmEPermuteDesc batched_gemm_e_permute_desc,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
CDEElementwiseOperation cde_element_op,
index_t BatchCount) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
static_cast<EDataType*>(p_e),
M,
N,
K,
stride_A,
stride_B,
batched_gemm_c_permute_desc,
batched_gemm_e_permute_desc,
a_element_op,
b_element_op,
c_element_op,
cde_element_op,
BatchCount);
}
......@@ -842,7 +651,7 @@ struct DeviceBatchedGemmCPermuteXdl : public DeviceBatchedGemmCPermute<AElementw
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedGemmCPermuteXdl"
str << "DeviceBatchedGemmEPermuteXdl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -46,12 +46,6 @@ struct DeviceGemmBiasCPermute : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceGemmBiasCPermutePtr = std::unique_ptr<
DeviceGemmBiasCPermute<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -10,8 +10,9 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_c_permute.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -35,7 +36,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_bias_c_permute(const FloatAB* __restrict__ p_a_grid,
kernel_gemm_bias_e_permute(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatDsPointer p_ds_grid,
FloatE* __restrict__ p_e_grid,
......@@ -99,7 +100,7 @@ template <typename ALayout,
typename CDELayout,
typename ADataType,
typename BDataType,
typename GemmAccDataType,
typename AccDataType,
typename CShuffleDataType,
typename DDataType,
typename EDataType,
......@@ -124,33 +125,36 @@ template <typename ALayout,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOperation,
struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGemmBiasCPermute_Xdl;
using DeviceOp = DeviceGemmBiasEPermute_Xdl;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t NumDTensor = I1;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
static constexpr index_t NumDTensor = 1;
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
......@@ -165,95 +169,10 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
}
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto MPad = M - MRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
assert(K % AK1 == 0);
const auto AK0 = K / AK1;
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
assert(KRaw % AK1 == 0);
const auto AK0 = KRaw / AK1;
const auto a_grid_desc_ak0_m_ak1 =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(MRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
......@@ -268,92 +187,7 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
}
}();
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
const auto NPad = N - NRaw;
const auto KPad = K - KRaw;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(NRaw, NPad),
make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
assert(K % BK1 == 0);
const auto BK0 = K / BK1;
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
assert(KRaw % BK1 == 0);
const auto BK0 = KRaw / BK1;
const auto b_grid_desc_bk0_n_bk1 =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(NRaw)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
static auto MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1 d_e_grid_desc)
......@@ -370,73 +204,32 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
index_t stride_N0 = d_e_grid_desc.stride_N0_;
index_t stride_N1 = d_e_grid_desc.stride_N1_;
const auto MRaw = M0 * M1 * M2;
const auto NRaw = N0 * N1;
const auto c_grid_desc_mraw_nraw = [&]() {
const auto c_grid_desc_m0_m1_m2_n0_n1 = make_naive_tensor_descriptor(
const auto e_grid_desc_mraw_nraw = [&]() {
const auto e_grid_desc_m0_m1_m2_n0_n1 = make_naive_tensor_descriptor(
make_tuple(M0, M1, M2, N0, N1),
make_tuple(stride_M0, stride_M1, stride_M2, stride_N0, stride_N1));
return transform_tensor_descriptor(
c_grid_desc_m0_m1_m2_n0_n1,
e_grid_desc_m0_m1_m2_n0_n1,
make_tuple(make_merge_transform(make_tuple(M0, M1, M2)),
make_merge_transform(make_tuple(N0, N1))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}();
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
const auto MPad = M - MRaw;
const auto NPad = N - NRaw;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad),
make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1{}));
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1{}));
using DsGridDesc_M_N = Tuple<EGridDesc_M_N>;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle<
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<DDataType>,
EDataType,
......@@ -444,8 +237,9 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
......@@ -480,6 +274,13 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// Argument
struct Argument : public BaseArgument
{
......@@ -499,12 +300,17 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
CDEElementwiseOperation cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{}, // FIXME
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_grid_desc)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
......@@ -522,8 +328,16 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
// populate pointer, desc for Ds
// D pointer
p_ds_grid_(I0) = static_cast<const DDataType*>(p_d_grid);
// D desc
ds_grid_desc_m_n_(I0) = DeviceOp::MakeEGridDescriptor_M_N(d_grid_desc);
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
{
......@@ -531,32 +345,37 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_);
p_ds_grid_(I0) = static_cast<const DDataType*>(p_d_grid);
const auto d_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N(d_grid_desc);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(I0) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
d_grid_desc_m_n);
ds_grid_desc_m_n_[I0]);
}
}
// private:
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// tensor descriptors for problem definiton
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>
ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
// type from E
EGridDesc_M_N e_grid_desc_m_n_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
......@@ -569,8 +388,9 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
......@@ -586,7 +406,7 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_gemm_bias_c_permute<
const auto kernel = kernel_gemm_bias_e_permute<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
......@@ -596,9 +416,7 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
ck::StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor>,
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2ETileMap,
has_main_loop>;
......@@ -622,18 +440,14 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
arg.block_2_etile_map_);
};
float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
ave_time = launch_kernel(integral_constant<bool, true>{});
return launch_kernel(integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{});
return launch_kernel(integral_constant<bool, false>{});
}
return ave_time;
}
// polymorphic
......@@ -651,8 +465,9 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
......@@ -741,7 +556,7 @@ struct DeviceGemmBiasCPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmBiasCPermute_Xdl"
str << "DeviceGemmBiasEPermute_Xdl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -205,12 +205,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE));
......@@ -329,7 +329,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
// populate pointer, batch stride, desc for Ds
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment