Commit f4ea00fc authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Make sure methods are only invoked on right place

parent 880bbc45
...@@ -217,7 +217,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -217,7 +217,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
DeviceOp::MakeDescriptor_M({M_, N_}, {I1, StrideC_}, grid_size, BlockSize); DeviceOp::MakeDescriptor_M({M_, N_}, {I1, StrideC_}, grid_size, BlockSize);
} }
p_aux_2_grid_ = p_workspace + Parent::c_grid_desc_m_n.GetElementSpaceSize(); p_aux_2_grid_ = p_workspace + GetCElementSpaceSize(M_, N_, StrideC_);
} }
// private: // private:
...@@ -561,6 +561,14 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -561,6 +561,14 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
return str.str(); return str.str();
} }
static std::size_t GetCElementSpaceSize(index_t M, index_t N, index_t StrideC)
{
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(
M, GridwiseGemm::CalculateMPadded(M), N, GridwiseGemm::CalculateNPadded(N), StrideC);
return c_grid_desc_m_n.GetElementSpaceSize();
}
std::size_t GetWorkspaceSize(index_t M, std::size_t GetWorkspaceSize(index_t M,
index_t N, index_t N,
[[maybe_unused]] index_t K, [[maybe_unused]] index_t K,
...@@ -568,10 +576,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -568,10 +576,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
[[maybe_unused]] index_t StrideB, [[maybe_unused]] index_t StrideB,
index_t StrideC) override index_t StrideC) override
{ {
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N( return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC);
M, GridwiseGemm::CalculateMPadded(M), N, GridwiseGemm::CalculateNPadded(N), StrideC);
return 2 * sizeof(CDataType) * c_grid_desc_m_n.GetElementSpaceSize();
} }
}; };
......
...@@ -170,7 +170,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -170,7 +170,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
} }
} }
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
{ {
const auto a_grid_desc_mraw_kraw = [&]() { const auto a_grid_desc_mraw_kraw = [&]() {
...@@ -252,7 +252,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -252,7 +252,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
} }
} }
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
{ {
const auto b_grid_desc_nraw_kraw = [&]() { const auto b_grid_desc_nraw_kraw = [&]() {
...@@ -387,10 +387,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -387,10 +387,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
} }
} }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1, 1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1, 1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1));
// Argument // Argument
struct Argument : public tensor_operation::device::BaseArgument struct Argument : public tensor_operation::device::BaseArgument
{ {
...@@ -419,7 +415,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -419,7 +415,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{ {
} }
__host__ __device__ void Print() const __host__ void Print() const
{ {
std::cout << "arg {" std::cout << "arg {"
<< "M:" << M << ", " << "M:" << M << ", "
...@@ -435,10 +431,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -435,10 +431,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<< "BK0:" << BK0 << "}" << std::endl; << "BK0:" << BK0 << "}" << std::endl;
} }
__host__ __device__ Argument(const Argument&) = default;
__host__ __device__ ~Argument() override {}
index_t M; index_t M;
index_t N; index_t N;
index_t K; index_t K;
...@@ -456,7 +448,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -456,7 +448,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
...@@ -464,7 +456,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -464,7 +456,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1)); make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1));
} }
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{ {
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
...@@ -472,8 +464,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -472,8 +464,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1)); make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
} }
__host__ __device__ static constexpr auto __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{ {
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
...@@ -488,7 +479,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -488,7 +479,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
} }
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
...@@ -516,7 +507,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -516,7 +507,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ __device__ static constexpr bool CheckValidity(const Argument& karg) __host__ static constexpr bool CheckValidity(const Argument& karg)
{ {
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0, (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
...@@ -601,7 +592,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -601,7 +592,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return true; return true;
} }
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{ {
const index_t num_loop = K / KPerBlock; const index_t num_loop = K / KPerBlock;
...@@ -609,7 +600,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -609,7 +600,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
} }
template <typename CGridDesc> template <typename CGridDesc>
__host__ __device__ static constexpr auto __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_grid_desc_m_n) MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_grid_desc_m_n)
{ {
const auto M = c_grid_desc_m_n.GetLength(I0); const auto M = c_grid_desc_m_n.GetLength(I0);
...@@ -631,33 +622,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -631,33 +622,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>; using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
__host__ __device__ static void print_bytes(const uint8_t* memory, std::size_t size)
{
(void)memory;
(void)size;
for(std::size_t idx = 0; idx < size; ++idx)
{
if(idx % 10 == 0)
{
printf("\n");
}
printf("0x%02X ", static_cast<unsigned>(memory[idx]));
}
printf("\n");
}
template <typename T>
__host__ __device__ static void print_bytes(const T& obj)
{
uint8_t memory[sizeof(T)];
memcpy(memory, &obj, sizeof(T));
print_bytes(memory, sizeof(T));
}
template <bool HasMainKBlockLoop> template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment