Commit f4ea00fc authored by Po-Yen, Chen
Browse files

Make sure methods are only invoked in the right place

parent 880bbc45
......@@ -217,7 +217,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
DeviceOp::MakeDescriptor_M({M_, N_}, {I1, StrideC_}, grid_size, BlockSize);
}
p_aux_2_grid_ = p_workspace + Parent::c_grid_desc_m_n.GetElementSpaceSize();
p_aux_2_grid_ = p_workspace + GetCElementSpaceSize(M_, N_, StrideC_);
}
// private:
......@@ -561,6 +561,14 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
return str.str();
}
// Returns the element-space size reported by the C grid descriptor that
// GridwiseGemm builds for an M x N problem with row stride StrideC.
// The padded extents come from GridwiseGemm::CalculateMPadded / CalculateNPadded.
static std::size_t GetCElementSpaceSize(index_t M, index_t N, index_t StrideC)
{
    const auto m_padded = GridwiseGemm::CalculateMPadded(M);
    const auto n_padded = GridwiseGemm::CalculateNPadded(N);

    const auto c_desc =
        GridwiseGemm::MakeCGridDescriptor_M_N(M, m_padded, N, n_padded, StrideC);

    return c_desc.GetElementSpaceSize();
}
std::size_t GetWorkspaceSize(index_t M,
index_t N,
[[maybe_unused]] index_t K,
......@@ -568,10 +576,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
[[maybe_unused]] index_t StrideB,
index_t StrideC) override
{
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(
M, GridwiseGemm::CalculateMPadded(M), N, GridwiseGemm::CalculateNPadded(N), StrideC);
return 2 * sizeof(CDataType) * c_grid_desc_m_n.GetElementSpaceSize();
return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC);
}
};
......
......@@ -170,7 +170,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
}
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
__device__ static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
{
const auto a_grid_desc_mraw_kraw = [&]() {
......@@ -252,7 +252,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
}
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(
index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
{
const auto b_grid_desc_nraw_kraw = [&]() {
......@@ -387,10 +387,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
}
// Descriptor types deduced from the return types of the corresponding factory
// functions. decltype leaves its operand unevaluated, so the literal 1s are
// placeholders that only select the overload; no descriptor is built here.
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1, 1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1, 1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1));
// Argument
struct Argument : public tensor_operation::device::BaseArgument
{
......@@ -419,7 +415,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
}
__host__ __device__ void Print() const
__host__ void Print() const
{
std::cout << "arg {"
<< "M:" << M << ", "
......@@ -435,10 +431,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<< "BK0:" << BK0 << "}" << std::endl;
}
__host__ __device__ Argument(const Argument&) = default;
__host__ __device__ ~Argument() override {}
index_t M;
index_t N;
index_t K;
......@@ -456,7 +448,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
......@@ -464,7 +456,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
......@@ -472,8 +464,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
__device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
......@@ -488,7 +479,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
......@@ -516,7 +507,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
// The block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
__host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
__host__ static constexpr bool CheckValidity(const Argument& karg)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
......@@ -601,7 +592,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
......@@ -609,7 +600,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
template <typename CGridDesc>
__host__ __device__ static constexpr auto
__device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
......@@ -631,33 +622,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// return block_id to C matrix tile idx (m0, n0) mapping
using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
// Debug helper: prints `size` bytes starting at `memory` as "0x%02X " tokens,
// ten per row. A newline is written before every row (including the first)
// and once more after the last byte.
__host__ __device__ static void print_bytes(const uint8_t* memory, std::size_t size)
{
    constexpr std::size_t bytes_per_row = 10;

    // NOTE(review): both parameters are used below; these casts presumably
    // silence unused-parameter warnings in some build configuration - confirm.
    static_cast<void>(memory);
    static_cast<void>(size);

    std::size_t pos = 0;
    while(pos < size)
    {
        const bool starts_row = (pos % bytes_per_row) == 0;
        if(starts_row)
        {
            printf("\n");
        }

        printf("0x%02X ", static_cast<unsigned>(memory[pos]));
        ++pos;
    }

    printf("\n");
}
// Debug helper: copies the sizeof(T) bytes of `obj` into a local buffer and
// forwards that buffer to the (pointer, size) overload of print_bytes.
template <typename T>
__host__ __device__ static void print_bytes(const T& obj)
{
    constexpr std::size_t byte_count = sizeof(T);

    uint8_t raw[byte_count];
    memcpy(raw, &obj, byte_count);

    print_bytes(raw, byte_count);
}
template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment