Commit d16063db authored by aska-0096

tempsave

parent 98ccb367
......@@ -35,3 +35,8 @@ add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
enum struct LoopScheduler
{
Default,
};
constexpr LoopScheduler make_default_loop_scheduler()
{
return LoopScheduler::Default;
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t MPerWMMA,
index_t NPerWMMA,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t KPerBlock = BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto wmma_gemm = WmmaGemm<FloatAB, MPerWMMA, NPerWMMA, KPack>{};
static constexpr index_t KPerThread = KPerBlock / wmma_gemm.K0PerWMMA;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
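// Illustrative decomposition (hypothetical tile sizes, not mandated by this header):
// MPerBlock = 128, MRepeat = 2, MPerWMMA = 16  ->  MWaves = 128 / (2 * 16) = 4
// NPerBlock = 128, NRepeat = 4, NPerWMMA = 16  ->  NWaves = 128 / (4 * 16) = 2
// i.e. the block is expected to launch MWaves * NWaves waves of WaveSize lanes each.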
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
wmma_gemm.GetRegSizePerWMMA(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
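// The merge adaptor above is equivalent to the plain index arithmetic below
// (reference sketch only; the kernel uses the adaptor):
//   const index_t wave_id  = thread_id / WaveSize;
//   const index_t waveId_m = wave_id / NWaves;   // slowest-varying
//   const index_t waveId_n = wave_id % NWaves;
//   const index_t lane_id  = thread_id % WaveSize;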
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, WMMA_a_idx[I1], KPerThread * WMMA_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, WMMA_b_idx[I1], KPerThread * WMMA_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t WMMA_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<WMMA_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(WMMA_i, blk_i);
constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
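// The two unmerge adaptors above reduce to (sketch):
//   c_thread_m = (m0 * MWaves + waveId_m) * MPerWMMA + blk_idx[I0];
//   c_thread_n = (n0 * NWaves + waveId_n) * NPerWMMA + blk_idx[I1];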
template <index_t m0, index_t n0, index_t WMMA_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<WMMA_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk4D(WMMA_i, blk_i);
return make_tuple(Number<m0>{},
Number<n0>{},
waveId_m,
waveId_n,
blk_idx[I0],
blk_idx[I1],
blk_idx[I2],
blk_idx[I3]);
}
__host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2()
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong! MPerBlock must be divisible by MPerWMMA * MRepeat and NPerBlock by NPerWMMA * NRepeat");
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = wmma_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = wmma_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerWMMA>{},
Number<NPerWMMA>{}));
return wmma_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerWMMA>{},
Number<NPerWMMA>{}));
return wmma_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_block_desc_g_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return wmma_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
}
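// Here M is unmerged as (M / (MWaves * MPerWMMA), MWaves, MPerWMMA) and N likewise;
// e.g. (illustrative) M = 1024, MWaves = 4, MPerWMMA = 16 gives a leading length of
// 1024 / 64 = 16, which spans the M-block index together with MRepeat.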
__host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
{
return transform_tensor_descriptor(
AK0MK1BlockDesc{},
make_tuple(
make_pass_through_transform(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{}));
}
__host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1()
{
return transform_tensor_descriptor(
BK0NK1BlockDesc{},
make_tuple(
make_pass_through_transform(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{}));
}
static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1();
static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
constexpr auto RepeatDiff = MRepeat - NRepeat;
constexpr auto WmmaK = wmma_gemm.k_per_wmma;
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto iWmmaK){
// Cut the MRepeat x NRepeat rectangle down to a square; assumes MRepeat > NRepeat
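// e.g. (hypothetical repeats) MRepeat = 4, NRepeat = 2: RepeatDiff = 2, so the first
// two iterations below consume the extra M rows and leave a 2 x 2 square of (m, n)
// tiles for the FIFO phase that follows.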
static_for<0, RepeatDiff, 1>{}([&](auto iCut){
static_for<0, NRepeat, 1>{}([&](auto iN){
vector_type<FloatAB, WmmaK> a_thread_vec;
vector_type<FloatAB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatAB>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(iCut, 0, 0, iK))>{}];
b_thread_vec.template AsType<FloatAB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(iN, 0, 0, iK))>{}];
});
using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type>(),
b_thread_vec.template AsType<wmma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<iWmmaK>{}, iCut, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
});
// Walk over the remaining square in FIFO fashion
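// Intent (as far as this tempsave code suggests): interleave the WMMA issues for one
// A row / one B column with the a_thread_copy_ / b_thread_copy_ loads of the next
// row and column, so register loads for step i+1 overlap the math of step i.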
static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){
static_for<WmmaInnerloop, NRepeat, 1>{}([&](auto iN){
vector_type<FloatAB, WmmaK> a_thread_vec;
vector_type<FloatAB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatAB>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(WmmaInnerloop+RepeatDiff, 0, 0, iK))>{}];
b_thread_vec.template AsType<FloatAB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(iN, 0, 0, iK))>{}];
});
using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop+RepeatDiff, iN, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type>(),
b_thread_vec.template AsType<wmma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<iWmmaK>{}, WmmaInnerloop+RepeatDiff, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
static_for<WmmaInnerloop+RepeatDiff, MRepeat, 1>{}([&](auto iM){
vector_type<FloatAB, WmmaK> a_thread_vec;
vector_type<FloatAB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto iK) {
a_thread_vec.template AsType<FloatAB>()(iK) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(iM, 0, 0, iK))>{}];
b_thread_vec.template AsType<FloatAB>()(iK) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(WmmaInnerloop, 0, 0, iK))>{}];
});
using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type>(),
b_thread_vec.template AsType<wmma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<iWmmaK>{}, WmmaInnerloop, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
});
});
}
protected:
// KPack equals the instruction K length (WmmaGemm asserts KPack == k_per_wmma), so it
// is reused here as WmmaK for the class-scope thread descriptors below.
static constexpr index_t WmmaK = KPack;
// A[M0, M1, M2, K0 = WmmaK]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<WmmaK>{}));
// B[N0, N1, N2, K0 = WmmaK]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<WmmaK>{}));
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWMMA()));
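// Per-thread accumulator footprint: MRepeat * NRepeat * GetRegSizePerWMMA() values of
// FloatAcc, e.g. 2 * 4 * 8 = 64 registers with the illustrative repeats above and the
// 8 acc registers of wmma_f32_16x16x16_f16_w32.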
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<1, 1, 1, WmmaK>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<1, 1, 1, WmmaK>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
};
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t MPerWMMA,
index_t NPerWMMA,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
LoopScheduler LoopSched>
constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2_Selector()
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2<BlockSize,
FloatAB,
FloatAcc,
AK0MK1BlockDesc,
BK0NK1BlockDesc,
MPerWMMA,
NPerWMMA,
MRepeat,
NRepeat,
KPack>{};
}
};
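// Usage sketch (illustrative block size, data types and tile parameters):
//   using BlockwiseGemm = decltype(
//       BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2_Selector<256,
//                                                           half_t,
//                                                           float,
//                                                           decltype(a_block_desc),
//                                                           decltype(b_block_desc),
//                                                           16,  // MPerWMMA
//                                                           16,  // NPerWMMA
//                                                           2,   // MRepeat
//                                                           4,   // NRepeat
//                                                           16,  // KPack
//                                                           LoopScheduler::Default>());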
} // namespace ck
......@@ -36,10 +36,10 @@ template <typename ADataType,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t K1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
ck::index_t MPerWMMA,
ck::index_t NPerWMMA,
ck::index_t MWmmaPerWave,
ck::index_t NWmmaPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
......@@ -217,11 +217,11 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXDL,
NPerXDL,
MPerWMMA,
NPerWMMA,
K1,
MXdlPerWave,
NXdlPerWave,
MWmmaPerWave,
NWmmaPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
......@@ -543,10 +543,10 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave
<< MPerWMMA << ", "
<< NPerWMMA << ", "
<< MWmmaPerWave << ", "
<< NWmmaPerWave
<< ">"
<< " NumPrefetch: "
<< NumPrefetch << ", "
......
......@@ -141,7 +141,7 @@ template <
index_t CBlockTransferScalarPerVector_NWaveNPerWmma,
index_t NumGemmKPrefetchStage = 1,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -160,52 +160,35 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K10_K1PerInst()
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_K10_MPerBlock_K1PerInst()
{
constexpr auto inst_max_size = 16 / sizeof(FloatAB);
constexpr auto k1perinst = (K1 < inst_max_size) ? K1 : inst_max_size;
constexpr auto K10 = K1 / k1perinst;
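// e.g. FloatAB = half: inst_max_size = 16 / 2 = 8; with an illustrative K1 = 16 this
// yields k1perinst = 8 and K10 = 2.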
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k10_k1perinst = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
// May have static err
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K10, k1perinst), k1perinst);
}
constexpr auto a_block_desc_k0_k10_m_k1perinst = [&]() {
// NOTE: this descriptor may still trip a compile-time (static) error; to be verified
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, K10, Number<MPerBlock>{}, k1perinst), k1perinst);
}();
return a_block_desc_k0_m_k1;
return a_block_desc_k0_k10_m_k1perinst;
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K10_K1PerInst()
__host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_K10_NPerBlock_K1PerInst()
{
constexpr auto inst_max_size = 16 / sizeof(FloatAB);
constexpr auto k1perinst = (K1 < inst_max_size) ? K1 : inst_max_size;
constexpr auto K10 = K1 / k1perinst;
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0_n_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K10, k1perinst), k1perinst);
}
constexpr auto b_block_desc_k0_k10_n_k1perinst = [&]() {
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, K10, Number<NPerBlock>{}, k1perinst), k1perinst);
}();
return b_block_desc_k0_n_k1;
return b_block_desc_k0_k10_n_k1perinst;
}
__host__ __device__ static constexpr auto
......@@ -230,18 +213,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto a_block_desc_k0_k10_m_k1perinst = GetABlockDescriptor_K0PerBlock_K10_MPerBlock_K1PerInst();
constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto b_block_desc_k0_k10_n_k1perinst = GetBBlockDescriptor_K0PerBlock_K10_NPerBlock_K1PerInst();
constexpr auto max_lds_align = K1;
constexpr auto max_lds_align = a_block_desc_k0_k10_m_k1perinst.GetLength(I3);
constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
math::integer_least_multiple(a_block_desc_k0_k10_m_k1perinst.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned =
math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
math::integer_least_multiple(b_block_desc_k0_k10_n_k1perinst.GetElementSpaceSize(), max_lds_align);
constexpr auto c_block_size = 0;
#ifndef DISABLE_C_SHUFFLE
// LDS allocation for C shuffle in LDS
constexpr auto c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma =
GetCBlockDescriptor_MBlock_NWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma();
......@@ -249,7 +234,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
constexpr auto c_block_size =
c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma
.GetElementSpaceSize();
#endif
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
c_block_size * sizeof(FloatC));
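// e.g. (illustrative sizes) K0PerBlock = 4, MPerBlock = NPerBlock = 128, K1 = 8 and
// half-precision FloatAB: A and B each occupy 4 * 128 * 8 * 2 B = 8 KiB of LDS, so the
// A+B term is 16 KiB unless the C-shuffle block is larger.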
......@@ -423,42 +408,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k10_k11 = GetABlockDescriptor_K0PerBlock_MPerBlock_K10_K1PerInst();
constexpr auto a_block_desc_k0_k10_m_k1perinst = GetABlockDescriptor_K0PerBlock_K10_MPerBlock_K1PerInst();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k10_k11 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K10_K1PerInst();
constexpr auto b_block_desc_k0_k10_n_k1perinst = GetBBlockDescriptor_K0PerBlock_K10_NPerBlock_K1PerInst();
// lds max alignment
constexpr auto max_lds_align = a_block_desc_k0_k10_m_k1perinst.GetLength(I3);
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_k0_m_k1),
decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
/* typename SrcElementwiseOperation, */ AElementwiseOperation,
/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
/* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
/* typename SrcData, */ FloatAB,
/* typename DstData, */ FloatAB,
/* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1),
/* typename DstDesc, */ decltype(a_block_desc_k0_k10_m_k1perinst),
/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
/* typename DstDimAccessOrder, */ Sequence<1, 0, 2>,
/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
/* index_t DstVectorDim, */ 2,
/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector,
/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1,
/* index_t SrcScalarStrideInVector, */ 1,
/* index_t DstScalarStrideInVector, */ 1,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
a_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0_m_k1,
a_block_desc_k0_k10_m_k1perinst,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
......@@ -474,7 +459,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
FloatAB,
FloatAB,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0_n_k1),
decltype(b_block_desc_k0_k10_n_k1perinst),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
......@@ -488,7 +473,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_k0_n_k1,
b_block_desc_k0_k10_n_k1perinst,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
......@@ -504,8 +489,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
BlockwiseGemmWmmaops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
decltype(a_block_desc_k0_k10_m_k1perinst),
decltype(b_block_desc_k0_k10_n_k1perinst),
MPerWmma,
NPerWmma,
MWmmaPerWave,
......@@ -516,14 +501,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned =
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
math::integer_least_multiple(a_block_desc_k0_k10_m_k1perinst.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
static_cast<FloatAB*>(p_shared), a_block_desc_k0_k10_m_k1perinst.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_k0_n_k1.GetElementSpaceSize());
b_block_desc_k0_k10_n_k1perinst.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
......@@ -532,13 +517,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
a_block_desc_k0_m_k1,
a_block_desc_k0_k10_m_k1perinst,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_k0_n_k1,
b_block_desc_k0_n_k1,
b_block_desc_k0_k10_n_k1perinst,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
......@@ -546,7 +531,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
blockwise_gemm,
c_thread_buf,
K0BlockMainLoop);
#ifndef DISABLE_C_SHUFFLE
// shuffle C and write out
{
static_assert(MWmmaPerWave % CShuffleMWmmaPerWavePerShuffle == 0 &&
......@@ -809,6 +794,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
}
});
}
#endif
}
};
......
......@@ -25,15 +25,15 @@ struct wmma_type;
template <>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_w32>
{
static constexpr index_t m_per_wave = 16;
static constexpr index_t n_per_wave = 16;
static constexpr index_t k_per_wave = 16;
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t wave_size = 32;
static constexpr index_t lane_size = 16;
static constexpr index_t src_data_size = 2;
static constexpr index_t acc_data_size = 4;
static constexpr index_t num_srcregs_per_wave = 8;
static constexpr index_t num_accregs_per_wave = 8;
static constexpr index_t num_srcregs_per_wmma = 8;
static constexpr index_t num_accregs_per_wmma = 8;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
......@@ -45,7 +45,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_w32>
template <typename src_type, typename dst_type, index_t MPerWmma, index_t NPerWmma>
struct WmmaSelector
{
template <typename src_type, typename dst_type, index_t MPerWmma_, index_t NPerWmma_>
template <typename src_type_, typename dst_type_, index_t MPerWmma_, index_t NPerWmma_>
static constexpr auto GetWmma();
template <>
......@@ -89,21 +89,21 @@ struct WmmaSelector
__host__ __device__ constexpr WmmaSelector()
{
static_assert(selected_wmma.m_per_wave == selected_wmma.n_per_wave,
static_assert(selected_wmma.m_per_wmma == selected_wmma.n_per_wmma,
"WRONG! WMMA_M must equal to WMMA_N");
static_assert(selected_wmma.m_per_wave == selected_wmma.k_per_wave,
static_assert(selected_wmma.m_per_wmma == selected_wmma.k_per_wmma,
"WRONG! WMMA_M must equal to WMMA_K");
static_assert(selected_wmma.k_per_wave == 16,
static_assert(selected_wmma.k_per_wmma == 16,
"WRONG! WMMA_M must equal to WMMA_N");
static_assert(selected_wmma.wave_size * selected_wmma.num_accregs_per_wave * selected_wmma.acc_data_size==
selected_wmma.m_per_wave * selected_wmma.n_per_wave * 4,
static_assert(selected_wmma.wave_size * selected_wmma.num_accregs_per_wmma * selected_wmma.acc_data_size==
selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
"WRONG! Number of Accumulator Register");
static_assert(selected_wmma.lane_size * selected_wmma.num_srcregs_per_wave * selected_wmma.src_data_size==
selected_wmma.m_per_wave * selected_wmma.k_per_wave * 4,
static_assert(selected_wmma.lane_size * selected_wmma.num_srcregs_per_wmma * selected_wmma.src_data_size==
selected_wmma.m_per_wmma * selected_wmma.k_per_wmma * 4,
"WRONG! Number of Source Register");
}
};
......@@ -126,20 +126,12 @@ struct WmmaGemm
using CIndex = MultiIndex<2>;
using CIndex4D = MultiIndex<4>;
__device__ static constexpr index_t GetNumBlks() { return wmma_instr.num_output_blks; }
__device__ static constexpr index_t GetNumXdlops()
{
return MPerWmma * NPerWmma /
(wmma_instr.m_per_blk * wmma_instr.n_per_blk * wmma_instr.num_output_blks);
}
__host__ __device__ constexpr WmmaGemm()
{
static_assert(NPerWmma == 16 && MPerWmma == 16 ,
"Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
static_assert(KPack % wmma_instr.k_per_wave == 0, "KPack cannot be divided by k_per_wave");
static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma");
}
// XDL output supporting C = A * B
......@@ -267,79 +259,43 @@ struct WmmaGemm
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| (is_same<src_type, int4_t>::value && is_same<dst_type, int32_t>::value)
#endif
,
"base type couple must be (half, float), (bhalf, float), (half, half),
(bhalf, bhalf), (int8, int32) or (int4, int32)!");
static_for<0, KPack / wmma_instr.k_per_wave, 1>{}([&](auto k) {
if constexpr(!TransposeC)
{
wmma_instr.template run<MPerWmma, NPerWmma>(
p_a_wave[k], p_b_wave[k], p_c_thread);
}
else
{
wmma_instr.template run<MPerWmma, NPerWmma>(
p_b_wave[k], p_a_wave[k], p_c_thread);
}
});
,"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), (int8, int32) or (int4, int32)!");
if constexpr(!TransposeC)
{
wmma_instr.template run<MPerWmma, NPerWmma>(
p_a_wave[0], p_b_wave[0], p_c_thread);
}
else
{
wmma_instr.template run<MPerWmma, NPerWmma>(
p_b_wave[0], p_a_wave[0], p_c_thread);
}
}
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }
__device__ static auto GetBlkIdx()
__device__ static auto GetLaneIdHigh()
{
const auto laneId = GetLaneId();
constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(
make_tuple(1, wmma_instr.num_input_blks, wmma_instr.num_threads_per_blk))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto blk_idx =
threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
const auto blk_id = blk_idx[I1];
const auto blk_td = blk_idx[I2];
return GetLaneId() / 16;
}
return make_tuple(blk_id, blk_td);
__device__ static auto GetLaneIdLow()
{
return GetLaneId() % 16;
}
__device__ static auto GetSwizzledLaneIdLow()
{
return ((GetLaneIdLow() & 1) << 3 ) | (GetLaneIdLow() >> 1);
}
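// Lane mapping produced by GetSwizzledLaneIdLow() above, for the 16 low lane ids:
//   0->0, 1->8, 2->1, 3->9, 4->2, 5->10, 6->3, 7->11,
//   8->4, 9->12, 10->5, 11->13, 12->6, 13->14, 14->7, 15->15
// i.e. even lanes map onto rows 0..7 and odd lanes onto rows 8..15.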
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
const auto laneId = GetLaneId();
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
if constexpr(wmma_instr.is_k_reduction)
{
return make_tuple(blk_id, blk_td);
}
else
{
return make_tuple(0, laneId);
}
return make_tuple(0, GetSwizzledLaneIdLow());
}
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
{
const auto laneId = GetLaneId();
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
if constexpr(wmma_instr.is_k_reduction)
{
return make_tuple(blk_id, blk_td);
}
else
{
return make_tuple(0, laneId);
}
return make_tuple(0, GetLaneIdLow());
}
__device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
......@@ -365,12 +321,12 @@ struct WmmaGemm
return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
}
static constexpr auto mfma = MfmaSelector<base_type, MPerWmma, NPerWmma>{};
static constexpr auto wmma = WmmaSelector<src_type, dst_type, MPerWmma, NPerWmma>{};
static constexpr auto wmma_instr = mfma.selected_mfma;
static constexpr auto wmma_instr = wmma.selected_wmma;
static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
static constexpr auto KPerXdlops = wmma.GetKPerXdlops();
static constexpr auto K1PerXdlops = wmma.GetK1PerXdlops();
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
__host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
......