Commit 9e1dd262 authored by Jing Zhang

clean code

parent 91075f0f
@@ -193,6 +193,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
static constexpr index_t NumDTensor = DsDataType::Size();
static const index_t k_batch = 1;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
@@ -574,15 +576,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
});
// tensor descriptors for problem definition
const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K(M, K, StrideA);
const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K(K, N, StrideB);
// const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K(M, K, StrideA);
// const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K(K, N, StrideB);
DsGridDesc_M_N ds_grid_desc_m_n;
// DsGridDesc_M_N ds_grid_desc_m_n;
std::array<index_t, NumDTensor> StrideDs;
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
// using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
if(gemm_descs[i].stride_Ds_.size() != NumDTensor)
{
@@ -590,9 +592,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
"wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
}
StrideDs[j] = gemm_descs[i].stride_Ds_[j];
ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
M, N, gemm_descs[i].stride_Ds_[j]);
StrideDs[j] = gemm_descs[i].stride_Ds_[j];
// ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
// M, N, gemm_descs[i].stride_Ds_[j]);
});
#if 0
@@ -619,32 +621,34 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
grid_size_ += grid_size_grp;
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
e_grid_desc_m_n,
local_b2c_tile_map))
// check block-to-E-tile
if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n))
{
gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{
p_As.size() == 0 ? nullptr : p_As[i],
p_Bs.size() == 0 ? nullptr : p_Bs[i],
p_ds_grid,
p_Es[i],
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
});
throw std::runtime_error("wrong! block_2_etile_map validation failed");
}
else
if(!GridwiseGemm::
template CheckValidity<ALayout, BLayout, DsLayout, ELayout, GemmSpec>(
M, N, K, StrideA, StrideB, StrideDs, StrideC, 1))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{
p_As.size() == 0 ? nullptr : p_As[i],
p_Bs.size() == 0 ? nullptr : p_Bs[i],
p_ds_grid,
p_Es[i],
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
});
group_id++;
}
}
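For reference, a plausible shape of GemmBiasTransKernelArg, inferred purely from the brace-initializer used in the push_back above (the trailing-underscore member names, the void pointer types, and DsGridPointer are assumptions; the real definition lives elsewhere in this header):

// Sketch only: fields inferred from the initializer list above.
struct GemmBiasTransKernelArg
{
    const void* a_ptr_;                        // p_As[i]
    const void* b_ptr_;                        // p_Bs[i]
    DsGridPointer ds_ptr_;                     // p_ds_grid
    void* e_ptr_;                              // p_Es[i]
    index_t M_, N_, K_;                        // per-group problem sizes (K_ is read again in the launcher below)
    index_t StrideA_, StrideB_;
    std::array<index_t, NumDTensor> StrideDs_;
    index_t StrideC_;
};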
@@ -682,7 +686,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
const auto KPad =
GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, 1);
GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, k_batch);
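// k_batch here is the device op's static member, defined as 1 near the top of this diff.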
if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop)
{
@@ -16,6 +16,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace ck {
@@ -393,6 +394,71 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
e_grid_desc_m_n);
}
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
GemmSpecialization GemmSpec>
__host__ __device__ static constexpr bool
CheckValidity(const index_t M,
const index_t N,
const index_t K,
const index_t StrideA,
const index_t StrideB,
const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
const index_t KBatch = 1)
{
const auto a_grid_desc_kbatch_ak0_m_ak1 =
MakeAGridDescriptor_KBatch_AK0_M_AK1<ALayout, GemmSpec>(M, K, StrideA, KBatch);
const auto b_grid_desc_kbatch_bk0_n_bk1 =
MakeBGridDescriptor_KBatch_BK0_N_BK1<BLayout, GemmSpec>(K, N, StrideB, KBatch);
ignore = StrideDs;
// using DsGridDesc_M_N =
// remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
// DsGridDesc_M_N ds_grid_desc_m_n;
// static_for<0, NumDTensor, 1>{}([&](auto j) {
// using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
// ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
//});
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
#if 0
// check tile size
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
return false;
}
#endif
// check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check tensor size: cannot be larger than 2GB each
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
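// (1 << 31) bytes == 2 GiB; the comparisons below are on byte counts (element count * element size).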
if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}
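A minimal sketch of how a caller is expected to use this overload on the host side, mirroring the call added in the device op earlier in this commit (the GridwiseGemm alias, layout types, GemmSpec, and stride values are whatever the enclosing device op provides; the error message is illustrative):

// Hypothetical host-side pre-launch check, shaped after the usage in the first file of this diff.
std::array<index_t, NumDTensor> stride_ds{}; // filled from the per-group GEMM descriptor
if(!GridwiseGemm::template CheckValidity<ALayout, BLayout, DsLayout, ELayout, GemmSpec>(
       M, N, K, StrideA, StrideB, stride_ds, StrideC, /*KBatch=*/1))
{
    throw std::runtime_error("wrong! GridwiseGemm has invalid setting for this problem size");
}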
#if 0
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename AGridDesc_M_K,
typename BGridDesc_N_K,
@@ -464,6 +530,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
return true;
}
#endif
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
@@ -616,12 +683,23 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
}
// HACK: this forces m/n_block_data_idx_on_grid into SGPR
const index_t kbatch_id = 0; //__builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
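// __builtin_amdgcn_readfirstlane marks the value as wave-uniform, so the compiler can keep it in a scalar register.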
// if(get_thread_local_1d_id() == 0)
//{
// printf("%d %d %d %d\n",
// get_block_1d_id(),
// kbatch_id,
// block_work_idx[I1],
// block_work_idx[I2]);
//}
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
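// e.g. with (hypothetical) AK1 = 8 and BK1 = 4, the LDS alignment is lcm(8, 4) = 8 elements.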
@@ -633,8 +711,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
constexpr auto b_block_desc_kbatch_bk0_n_bk1 =
GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1();
const index_t kbatch_id = 0;
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,