Commit 9739ede0 authored by aska-0096

temp save

parent 0a808724
@@ -35,7 +35,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
- < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 8>, 8>;
+ < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
using ReferenceGemmInstance = ck::tensor_operation::host::
......
@@ -8,6 +8,8 @@
#include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#define CK_MNK_LOOP
namespace ck {
template <index_t BlockSize,
@@ -25,7 +27,730 @@ template <index_t BlockSize,
* C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* KPACK == WMMA_K = 16
*/
struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto WmmaK = Number<16>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// Hardcode WaveSize, since the current HIP runtime (5.4.0-10984) does not return the correct value.
static constexpr index_t WaveSize = 32;
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t KPerBlock =
BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto wmma_gemm = WmmaGemm<FloatAB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
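// Worked example (assumed mapping of the positional template arguments in the example
// instance at the top of this diff: BlockSize = 256, MPerBlock = NPerBlock = 128,
// MPerWMMA = NPerWMMA = 16, MRepeat = 4, NRepeat = 2):
//   MWaves = 128 / (4 * 16) = 2, NWaves = 128 / (2 * 16) = 4,
//   and MWaves * NWaves * WaveSize = 2 * 4 * 32 = 256 threads, matching BlockSize.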
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
wmma_gemm.GetRegSizePerWmma(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
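// Illustration of the merge transform above (using the assumed MWaves = 2, NWaves = 4,
// WaveSize = 32 from the example instance): thread_id = (waveId_m * NWaves + waveId_n) *
// WaveSize + lane, so thread 70 maps back to waveId_m = 0, waveId_n = 2, lane = 6.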
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |MLane |KPack
return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |NLane |KPack
return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
}
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
__host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle()
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!");
}
// Thread level, register descriptor. Vector-write
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{},
I1,
MSubGroup,
Number<NRepeat>{},
I1,
NThreadPerSubGroup,
MAccVgprs));
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(
make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
}
// Thread level, register descriptor. Per-pixel write
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_MAccVgprs_NRepeat_NWave_NThreadPerSubGroup()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |MAccVgprs |NRepeat |NWave
// |NThreadPerSubGroup
make_tuple(Number<MRepeat>{},
I1,
MSubGroup,
MAccVgprs,
Number<NRepeat>{},
I1,
NThreadPerSubGroup));
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(
make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
}
// Provide dimension size
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
__host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
{
return transform_tensor_descriptor(
AK0MK1BlockDesc{},
make_tuple(make_pass_through_transform(Number<A_K0>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{})),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
__host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1()
{
return transform_tensor_descriptor(
BK0NK1BlockDesc{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1();
static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
// auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
// a_thread_desc_.GetElementSpaceSize());
// auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
// b_thread_desc_.GetElementSpaceSize());
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAB,
MRepeat,
WmmaK,
true>
a_thread_buf;
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAB,
NRepeat,
WmmaK,
true>
b_thread_buf;
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
a_thread_buf.GetVectorTypeReference(Number<m0*WmmaK>{}).template AsType<FloatAB>());
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
b_thread_buf.GetVectorTypeReference(Number<n0*WmmaK>{}).template AsType<FloatAB>());
// vector_type<FloatAB, WmmaK> a_thread_vec;
// vector_type<FloatAB, WmmaK> b_thread_vec;
// static_for<0, WmmaK, 1>{}([&](auto i) {
// a_thread_vec.template AsType<FloatAB>()(i) =
// a_thread_buf[Number<a_thread_desc_.CalculateOffset(
// make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
// b_thread_vec.template AsType<FloatAB>()(i) =
// b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
// });
// using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_buf.GetVectorTypeReference(Number<m0*WmmaK>{}),
b_thread_buf.GetVectorTypeReference(Number<n0*WmmaK>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
protected:
// A[M0, M1, M2, K0 = WmmaK]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<WmmaK / A_K1>{}, I1, I1, I1, Number<A_K1>{}));
// B[N0, N1, N2, K0 = WmmaK]
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<WmmaK / B_K1>{}, I1, I1, I1, Number<B_K1>{}));
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<WmmaK / A_K1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4>,
4,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<WmmaK / B_K1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4>,
4,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
};
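// Rough sketch of the two loop orders in this header (added commentary, not original author
// text). The struct above keeps K outermost and reloads one WmmaK slice of A/B registers per
// iteration:
//
//   for k in [0, KPerBlock / WmmaK):
//       for m0 in [0, MRepeat):     load A(k, m0)
//           for n0 in [0, NRepeat): load B(k, n0); wmma(m0, n0)
//
// The _MNKloop variant below (presumably the one selected via the CK_MNK_LOOP define added at
// the top of this file) instead loads the full KPerBlock strip of a repeat once and keeps K
// innermost:
//
//   for m0 in [0, MRepeat):       load A(all k, m0)
//       for n0 in [0, NRepeat):   load B(all k, n0)
//           for k in [0, KPerBlock / WmmaK): wmma(m0, n0, k)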
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t MPerWMMA,
index_t NPerWMMA,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
/* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* KPACK == WMMA_K = 16
*/
struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto WmmaK = Number<16>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// Hardcode WaveSize, since the current HIP runtime (5.4.0-10984) does not return the correct value.
static constexpr index_t WaveSize = 32;
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t KPerBlock =
BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto wmma_gemm = WmmaGemm<FloatAB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
wmma_gemm.GetRegSizePerWmma(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = ThisThreadBlock::GetThreadId();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |MLane |KPack
return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |NLane |KPack
return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
}
template <index_t m0, index_t n0>
__device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
__host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop()
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
NPerBlock % (NPerWMMA * NRepeat) == 0,
"wrong!");
}
// Thread level, register descriptor. Vector-write
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{},
I1,
MSubGroup,
Number<NRepeat>{},
I1,
NThreadPerSubGroup,
MAccVgprs));
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(
make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
}
// Thread level, register descriptor. Per-pixel write
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_MAccVgprs_NRepeat_NWave_NThreadPerSubGroup()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |MAccVgprs |NRepeat |NWave
// |NThreadPerSubGroup
make_tuple(Number<MRepeat>{},
I1,
MSubGroup,
MAccVgprs,
Number<NRepeat>{},
I1,
NThreadPerSubGroup));
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma =
transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(
make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma);
}
// Provide dimension size
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{
constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWMMA>{},
Number<NRepeat>{},
Number<NWaves>{},
Number<NPerWMMA>{}));
return wmma_gemm
.MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
}
__host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
{
return transform_tensor_descriptor(
AK0MK1BlockDesc{},
make_tuple(make_pass_through_transform(Number<A_K0>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{})),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
__host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1()
{
return transform_tensor_descriptor(
BK0NK1BlockDesc{},
make_tuple(make_pass_through_transform(Number<B_K0>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
}
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1();
static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
b_thread_buf);
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
vector_type<FloatAB, WmmaK> a_thread_vec;
vector_type<FloatAB, WmmaK> b_thread_vec;
static_for<0, WmmaK, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple((k*WmmaK + i) / A_K1, 0, 0, 0, (k*WmmaK + i) % A_K1))>{}];
b_thread_vec.template AsType<FloatAB>()(i) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple((k*WmmaK + i) / B_K1, 0, 0, 0, (k*WmmaK + i) % B_K1))>{}];
});
using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
protected:
// A[M0, M1, M2, K0 = WmmaK]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPerBlock / A_K1>{}, I1, I1, I1, Number<A_K1>{}));
// B[N0, N1, N2, K0 = WmmaK]
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPerBlock / B_K1>{}, I1, I1, I1, Number<B_K1>{}));
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPerBlock / A_K1, 1, 1, 1, A_K1>,
Sequence<3, 0, 1, 2, 4>,
4,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPerBlock / B_K1, 1, 1, 1, B_K1>,
Sequence<3, 0, 1, 2, 4>,
4,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
};
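// Reading of the _FIFO variant below (added commentary; the intent is inferred from the Run()
// body, not stated by the author). Per WmmaK slice, Stage 1 handles the rows by which the
// MRepeat x NRepeat tile exceeds an NRepeat x NRepeat square (RepeatDiff = MRepeat - NRepeat),
// and Stage 2 walks the remaining square in an L-shaped order: for each WmmaInnerloop it sweeps
// one row across the remaining columns, then one column down the remaining rows. Together with
// the MRepeat/NRepeat-sized a_thread_desc_/b_thread_desc_ (see the changed descriptors further
// down), this appears intended to reuse recently loaded A/B register tiles.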
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t MPerWMMA,
index_t NPerWMMA,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
/* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* KPACK == WMMA_K = 16
*/
struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -125,7 +850,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
return make_tuple(c_thread_m, c_thread_n);
}
- __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle()
+ __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO()
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
@@ -283,33 +1008,26 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
- // constexpr auto RepeatDiff = MRepeat - NRepeat;
+ constexpr auto RepeatDiff = MRepeat - NRepeat;
- // debug_hexprinter(0xffffffff, a_thread_buf[Number<a_thread_desc_.CalculateOffset(
- // make_tuple(0, 0, 0, 0,0))>{}], "Avalue ");
- /* First local prefetch, move out of blockwise operation.
- static_for<0, NRepeat, 1>{}([&](auto iN){
- b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
- make_tuple(I0, Number<iN>{}, I0, I0, I0),
- b_block_buf,
- b_thread_desc_,
- make_tuple(I0, Number<iN>{}, I0, I0, I0),
- b_thread_buf);
- });
- static_for<0, MRepeat, 1>{}([&](auto iN){
- b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
- make_tuple(I0, Number<iN>{}, I0, I0, I0),
- b_block_buf,
- b_thread_desc_,
- make_tuple(I0, Number<iN>{}, I0, I0, I0),
- b_thread_buf);
- });
- */
- /*
static_for<0, KPerBlock, WmmaK>{}([&](auto iWmmaK){
- // Cut to Repeat Retangle to Square, assume MRepeat > NRepeat
+ // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat
static_for<0, RepeatDiff, 1>{}([&](auto iCut){
+ a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
+ make_tuple(Number<iWmmaK/A_K1>{}, Number<iCut>{}, I0, I0, I0),
+ a_block_buf,
+ a_thread_desc_,
+ make_tuple(I0, Number<iCut>{}, I0, I0, I0),
+ a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto iN){
+ b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
+ make_tuple(Number<iWmmaK/B_K1>{}, Number<iN>{}, I0, I0, I0),
+ b_block_buf,
+ b_thread_desc_,
+ make_tuple(I0, Number<iN>{}, I0, I0, I0),
+ b_thread_buf);
vector_type<FloatAB, WmmaK> a_thread_vec;
vector_type<FloatAB, WmmaK> b_thread_vec;
@@ -323,22 +1041,31 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
});
using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
- constexpr index_t c_offset =
- c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
- // debug_hexprinter(0x3c003c00, a_thread_vec.template
- AsType<FloatAB>()(Number<0>{})); wmma_gemm.template Run( a_thread_vec.template
- AsType<wmma_input_type>()(Number<0>{}), b_thread_vec.template
- AsType<wmma_input_type>()(Number<0>{}),
+ constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
+ wmma_gemm.template Run(
+ a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
+ b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
- a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
- make_tuple(Number<iWmmaK/A_K1>{}, Number<iCut>{}, I0, I0,
- Number<iWmmaK%A_K1>{}), a_block_buf, a_thread_desc_, make_tuple(I0, Number<iCut>{}, I0, I0,
- I0), a_thread_buf);
});
- // Run FIFO fashion loopover in Square
+ // Stage 2: Run FIFO fashion loopover in Square
static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){
+ // Row Repeatation
static_for<WmmaInnerloop, NRepeat, 1>{}([&](auto iN){
+ a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
+ make_tuple(Number<iWmmaK/A_K1>{}, Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, I0),
+ a_block_buf,
+ a_thread_desc_,
+ make_tuple(I0, Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, I0),
+ a_thread_buf);
+ b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
+ make_tuple(Number<iWmmaK/B_K1>{}, Number<iN>{}, I0, I0, I0),
+ b_block_buf,
+ b_thread_desc_,
+ make_tuple(I0, Number<iN>{}, I0, I0, I0),
+ b_thread_buf);
vector_type<FloatAB, WmmaK> a_thread_vec;
vector_type<FloatAB, WmmaK> b_thread_vec;
@@ -352,20 +1079,29 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
});
using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop+RepeatDiff, iN, 0));
wmma_gemm.template Run(
a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
- a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
- make_tuple(Number<iWmmaK/A_K1>{},
- Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, Number<iWmmaK%A_K1>{}), a_block_buf,
- a_thread_desc_,
- make_tuple(I0, Number<WmmaInnerloop+RepeatDiff>{}, I0, I0, I0),
- a_thread_buf);
- static_for<WmmaInnerloop+RepeatDiff, MRepeat, 1>{}([&](auto iM){
+ // WmmaInnerloop++
+ // Col Repeatation
+ static_for<WmmaInnerloop+1+RepeatDiff, MRepeat, 1>{}([&](auto iM){
+ a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
+ make_tuple(Number<iWmmaK/A_K1>{}, Number<iM>{}, I0, I0, I0),
+ a_block_buf,
+ a_thread_desc_,
+ make_tuple(I0, Number<iM>{}, I0, I0, I0),
+ a_thread_buf);
+ b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
+ make_tuple(Number<iWmmaK/B_K1>{}, Number<WmmaInnerloop>{}, I0, I0, I0),
+ b_block_buf,
+ b_thread_desc_,
+ make_tuple(I0, Number<WmmaInnerloop>{}, I0, I0, I0),
+ b_thread_buf);
vector_type<FloatAB, WmmaK> a_thread_vec;
vector_type<FloatAB, WmmaK> b_thread_vec;
@@ -386,54 +1122,6 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
- b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
- make_tuple(Number<iWmmaK/B_K1>{}, Number<WmmaInnerloop>{}, I0,
- I0, Number<iWmmaK%B_K1>{}), b_block_buf, b_thread_desc_, make_tuple(I0,
- Number<WmmaInnerloop>{}, I0, I0, I0), b_thread_buf);
- });
- });
- */
- static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
- static_for<0, MRepeat, 1>{}([&](auto m0) {
- // read A
- a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
- make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
- a_block_buf,
- a_thread_desc_,
- make_tuple(I0, I0, I0, I0, I0),
- a_thread_buf);
- static_for<0, NRepeat, 1>{}([&](auto n0) {
- // read B
- b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
- make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
- b_block_buf,
- b_thread_desc_,
- make_tuple(I0, I0, I0, I0, I0),
- b_thread_buf);
- vector_type<FloatAB, WmmaK> a_thread_vec;
- vector_type<FloatAB, WmmaK> b_thread_vec;
- static_for<0, WmmaK, 1>{}([&](auto i) {
- a_thread_vec.template AsType<FloatAB>()(i) =
- a_thread_buf[Number<a_thread_desc_.CalculateOffset(
- make_tuple(i / A_K1, 0, 0, 0, i % A_K1))>{}];
- b_thread_vec.template AsType<FloatAB>()(i) =
- b_thread_buf[Number<b_thread_desc_.CalculateOffset(
- make_tuple(i / B_K1, 0, 0, 0, i % B_K1))>{}];
- });
- using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
- constexpr index_t c_offset =
- c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
- wmma_gemm.template Run(
- a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
- b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
- c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
- });
});
});
}
@@ -441,11 +1129,11 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
protected:
// A[M0, M1, M2, K0 = WmmaK]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
- make_tuple(Number<WmmaK / A_K1>{}, I1, I1, I1, Number<A_K1>{}));
+ make_tuple(Number<WmmaK / A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}));
// B[N0, N1, N2, K0 = WmmaK]
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
- make_tuple(Number<WmmaK / B_K1>{}, I1, I1, I1, Number<B_K1>{}));
+ make_tuple(Number<WmmaK / B_K1>{}, Number<NRepeat>{}, I1, I1, Number<B_K1>{}));
// C[M, N, NumRegWMMA]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
@@ -503,4 +1191,60 @@ constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_Selector()
}
};
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t MPerWMMA,
index_t NPerWMMA,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
LoopScheduler LoopSched>
constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop_Selector()
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop<BlockSize,
FloatAB,
FloatAcc,
AK0MK1BlockDesc,
BK0NK1BlockDesc,
MPerWMMA,
NPerWMMA,
MRepeat,
NRepeat,
KPack>{};
}
};
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t MPerWMMA,
index_t NPerWMMA,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
LoopScheduler LoopSched>
constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO_Selector()
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
FloatAB,
FloatAcc,
AK0MK1BlockDesc,
BK0NK1BlockDesc,
MPerWMMA,
NPerWMMA,
MRepeat,
NRepeat,
KPack>{};
}
};
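// Usage sketch (illustrative only; the kernel-side names and template arguments below are
// assumptions, not taken from this commit):
//
//   auto blockwise_gemm =
//       BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO_Selector<
//           BlockSize, FloatAB, FloatAcc,
//           decltype(a_block_desc_k0perblock_mperblock_k1),
//           decltype(b_block_desc_k0perblock_nperblock_k1),
//           MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, LoopSched>();
//   auto& c_thread_buf = blockwise_gemm.GetCThreadBuffer();
//   blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);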
} // namespace ck
@@ -481,7 +481,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1),
/* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1),
/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
- /* typename DstDimAccessOrder, */ Sequence<1, 0, 2>,
+ /* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
/* index_t DstVectorDim, */ 2,
/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector,
@@ -513,7 +513,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0perblock_nperblock_k1),
BBlockTransferSrcAccessOrder,
- Sequence<1, 0, 2>,
+ Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
......
@@ -355,5 +355,17 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
c3);
}
// Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a,
half16_t b,
float8_t& c)
{
asm volatile("\n \
v_wmma_f32_16x16x16_f16_w32 %0, %1, %2, %0\n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
}
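// Usage sketch (illustrative): in wave32 mode each lane supplies 16 fp16 values of its A and B
// fragments and accumulates into 8 fp32 values; the "0" constraint ties the accumulator input
// to the output register, so c is updated in place.
//
//   half16_t a_frag = ...; // this lane's slice of the 16x16 A tile (loaded elsewhere)
//   half16_t b_frag = ...; // this lane's slice of the 16x16 B tile
//   float8_t acc    = ...; // this lane's 8 fp32 accumulators
//   amd_assembly_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc); // acc += a_frag * b_frag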
} // namespace ck
#endif