Commit 482ca684 authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'dev/a8w8_b_preshuffle' of...

Merge branch 'dev/a8w8_b_preshuffle' of https://github.com/ROCm/composable_kernel into add_a8w8_preshuffle_ckprofiler
parents 74ef5021 db843529
...@@ -516,6 +516,10 @@ include_directories(BEFORE ...@@ -516,6 +516,10 @@ include_directories(BEFORE
) )
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
add_compile_options(-Werror)
add_compile_options(-Weverything)
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
......
...@@ -137,7 +137,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -137,7 +137,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::a_block_desc_m0_m1_m2_k; using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;
using Base::AMmaKStride; using Base::AMmaKStride;
using Base::BMmaKStride; using Base::BMmaKStride;
...@@ -271,10 +270,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -271,10 +270,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
typename ABlockBuffer, typename ABlockBuffer,
typename ABlockTransferStep, typename ABlockTransferStep,
typename BGridDesc, typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer, typename BBlockTransfer,
typename BGridBuffer, typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep, typename BBlockTransferStep,
typename CThreadBuffer> typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc, __device__ void Run(const AGridDesc& a_grid_desc,
...@@ -285,10 +282,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -285,10 +282,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
ABlockBuffer& a_block_buf1, ABlockBuffer& a_block_buf1,
const ABlockTransferStep& a_block_copy_step, const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc, const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy, BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf, const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step, const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
index_t num_loop) const index_t num_loop) const
...@@ -296,8 +291,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr ...@@ -296,8 +291,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>( auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
a_thread_desc_.GetElementSpaceSize()); a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize());
// Global prefetch 1 // Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
......
...@@ -124,21 +124,26 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -124,21 +124,26 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock =
CDEShuffleBlockTransferScalarPerVectors{}[I0]; CDEShuffleBlockTransferScalarPerVectors{}[I0];
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{}; static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{}; static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1Number = Number<AK1Value>{}; static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{}; static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr auto BlockSizeNumber = Number<BlockSize>{}; static constexpr auto BlockSizeNumber = Number<BlockSize>{};
static constexpr index_t NLane = 32;
static constexpr index_t NWave = 4;
static constexpr index_t KLane = 2;
static constexpr index_t KRepeat = 8;
static_assert(NLane * NWave * KLane == BlockSize);
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
static constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
static constexpr index_t KLane =
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
static_assert(NLane * NWave * KLane == BlockSize);
static_assert(NXdlPerWave == 1, "only 1 validated now, tbd next week");
static constexpr auto MakeDsGridPointer() static constexpr auto MakeDsGridPointer()
{ {
return generate_tuple( return generate_tuple(
...@@ -152,10 +157,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -152,10 +157,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
using DsGridPointer = decltype(MakeDsGridPointer()); using DsGridPointer = decltype(MakeDsGridPointer());
static constexpr index_t KPack = math::max(
math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
...@@ -321,10 +322,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -321,10 +322,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{ {
constexpr index_t NKSWIZZLE_V = BlockSize * KPack; constexpr index_t NkSwizzle = BlockSize * KPack;
constexpr index_t NKSWIZZLE_N = Number<NKSWIZZLE_V>{}; constexpr index_t NkSwizzleNumber = Number<NkSwizzle>{};
return make_naive_tensor_descriptor(make_tuple(N0, K0, NKSWIZZLE_N), return make_naive_tensor_descriptor(make_tuple(N0, K0, NkSwizzleNumber),
make_tuple(K0 * NKSWIZZLE_V, NKSWIZZLE_N, I1)); make_tuple(K0 * NkSwizzle, NkSwizzleNumber, I1));
} }
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
...@@ -422,9 +423,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -422,9 +423,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{ {
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
} }
template <typename ELayout> template <typename ELayout>
...@@ -942,7 +941,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -942,7 +941,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{ {
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed( make_naive_tensor_descriptor_packed(
...@@ -982,17 +980,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -982,17 +980,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
{ {
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment // lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple( constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
// constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
// b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle in LDS // LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
...@@ -1296,9 +1289,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1296,9 +1289,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
const index_t n_block_data_idx_on_grid = const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * (NPerBlock / NLane / NWave)); __builtin_amdgcn_readfirstlane(block_n_id * (NPerBlock / NLane / NWave));
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
...@@ -1369,20 +1359,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1369,20 +1359,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{});
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
// Cast after lds // Cast after lds
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto a_block_buf1 = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf1 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<LDSTypeA*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeB*>(p_shared) +
a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, KRepeat, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, KRepeat, 0);
...@@ -1403,10 +1385,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1403,10 +1385,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
a_block_buf1, a_block_buf1,
a_block_slice_copy_step, a_block_slice_copy_step,
b_grid_desc_bpreshuffled, b_grid_desc_bpreshuffled,
b_block_desc_bk0_n_bk1,
b_blockwise_copy, b_blockwise_copy,
b_grid_buf, b_grid_buf,
b_block_buf,
b_block_slice_copy_step, b_block_slice_copy_step,
c_thread_buf, c_thread_buf,
num_k_block_main_loop); num_k_block_main_loop);
...@@ -1418,7 +1398,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle ...@@ -1418,7 +1398,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
"wrong!"); "wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it! // TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment