Commit 482ca684 authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'dev/a8w8_b_preshuffle' of...

Merge branch 'dev/a8w8_b_preshuffle' of https://github.com/ROCm/composable_kernel into add_a8w8_preshuffle_ckprofiler
parents 74ef5021 db843529
......@@ -516,6 +516,10 @@ include_directories(BEFORE
)
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
add_compile_options(-Werror)
add_compile_options(-Weverything)
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
......
......@@ -137,7 +137,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;
using Base::AMmaKStride;
using Base::BMmaKStride;
......@@ -271,10 +270,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
......@@ -285,10 +282,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
ABlockBuffer& a_block_buf1,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
......@@ -296,8 +291,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
__builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize());
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
......
......@@ -124,21 +124,26 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock =
CDEShuffleBlockTransferScalarPerVectors{}[I0];
// K1 should be Number<...>
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr auto BlockSizeNumber = Number<BlockSize>{};
static constexpr index_t NLane = 32;
static constexpr index_t NWave = 4;
static constexpr index_t KLane = 2;
static constexpr index_t KRepeat = 8;
static_assert(NLane * NWave * KLane == BlockSize);
static constexpr index_t NumDTensor = DsDataType::Size();
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
static constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
static constexpr index_t KLane =
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
static_assert(NLane * NWave * KLane == BlockSize);
static_assert(NXdlPerWave == 1, "only 1 validated now, tbd next week");
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
......@@ -152,10 +157,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
using DsGridPointer = decltype(MakeDsGridPointer());
static constexpr index_t KPack = math::max(
math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
......@@ -321,10 +322,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
{
constexpr index_t NKSWIZZLE_V = BlockSize * KPack;
constexpr index_t NKSWIZZLE_N = Number<NKSWIZZLE_V>{};
return make_naive_tensor_descriptor(make_tuple(N0, K0, NKSWIZZLE_N),
make_tuple(K0 * NKSWIZZLE_V, NKSWIZZLE_N, I1));
constexpr index_t NkSwizzle = BlockSize * KPack;
constexpr index_t NkSwizzleNumber = Number<NkSwizzle>{};
return make_naive_tensor_descriptor(make_tuple(N0, K0, NkSwizzleNumber),
make_tuple(K0 * NkSwizzle, NkSwizzleNumber, I1));
}
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
......@@ -422,9 +423,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__ __device__ static constexpr auto
MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
}
template <typename ELayout>
......@@ -942,7 +941,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
......@@ -982,17 +980,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
// constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
// b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
......@@ -1296,9 +1289,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * (NPerBlock / NLane / NWave));
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
......@@ -1369,20 +1359,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
ck::tensor_operation::element_wise::PassThrough{});
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
// Cast after lds
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto a_block_buf1 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeB*>(p_shared) +
a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, KRepeat, 0);
......@@ -1403,10 +1385,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
a_block_buf1,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
num_k_block_main_loop);
......@@ -1418,7 +1398,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment