Unverified Commit 0b547a33 authored by Raman R jana, committed by GitHub

Merge pull request #309 from ramjana/wavelet_model

fixed clang format errors
parents 599497b0 702c3379
@@ -57,7 +57,7 @@ struct ThreadGroupTensorSliceTransfer_v6r1
            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        // static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
        //               "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
...
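The commented-out assertion checks that the thread group has enough threads for the cluster; the active `static_assert` checks that the per-thread slice lengths times the thread-cluster lengths exactly cover the slicing window. A minimal sketch of that coverage check, using hypothetical lengths that are not taken from this kernel:

```cpp
#include <array>
#include <cstddef>

// Hypothetical example: an 8 x 64 slicing window covered by a 4 x 32 thread cluster,
// each thread copying a 2 x 2 sub-slice.
constexpr std::array<std::size_t, 2> slice_lengths{8, 64};
constexpr std::array<std::size_t, 2> thread_cluster_lengths{4, 32};
constexpr std::array<std::size_t, 2> thread_slice_lengths{2, 2};

constexpr bool covers_slicing_window()
{
    // Mirrors the static_assert: per-thread slice * cluster == whole window in every dimension.
    for(std::size_t d = 0; d < slice_lengths.size(); ++d)
    {
        if(thread_slice_lengths[d] * thread_cluster_lengths[d] != slice_lengths[d])
            return false;
    }
    return true;
}

static_assert(covers_slicing_window(), "threads must cover the entire slicing window");
```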
@@ -10,7 +10,6 @@
#include "gridwise_gemm_xdl_waveletmodel_cshuffle.hpp"
#include "gemm_specialization.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
@@ -438,7 +437,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
#if 0
            {
...
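The `Run` method whose signature is reformatted above belongs to the operator's Invoker. For context, a hedged sketch of how such a device op is typically driven from the host; `DeviceGemmInstance` and the `MakeArgument` parameter order follow the common composable_kernel DeviceGemm convention and are assumptions, not taken from this diff:

```cpp
#include <stdexcept>

// Hedged sketch: DeviceGemmInstance stands for a concrete
// DeviceGemm_Xdl_WaveletModel_CShuffle<...> instantiation; the MakeArgument parameter
// order follows the usual CK DeviceGemm convention and may differ for this operator.
template <typename DeviceGemmInstance, typename ADataType, typename BDataType, typename CDataType>
float run_gemm(const ADataType* p_a,
               const BDataType* p_b,
               CDataType* p_c,
               ck::index_t M,
               ck::index_t N,
               ck::index_t K,
               ck::index_t StrideA,
               ck::index_t StrideB,
               ck::index_t StrideC)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    auto gemm     = DeviceGemmInstance{};
    auto argument = gemm.MakeArgument(
        p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, PassThrough{}, PassThrough{}, PassThrough{});

    if(!gemm.IsSupportedArgument(argument))
        throw std::runtime_error("GEMM problem not supported by this instance");

    // Invoker::Run is the method reformatted above; time_kernel = true returns the average time in ms.
    return gemm.MakeInvoker().Run(argument, StreamConfig{nullptr, true});
}
```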
@@ -7,8 +7,8 @@ namespace ck {
template <typename TileLoadThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmLoadWave;

// 1-stage prefetch
template <typename TileLoadThreadGroup>
struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
{
    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
@@ -53,21 +53,21 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        // move to 1
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // LDS write 0
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                // sync for Load threads()
                block_sync_lds();

                // global read i + 1
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
@@ -81,11 +81,10 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
                // sync with math threads()
                block_sync_lds();

                // LDS write i+1
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                ++i;
            } while(i < (num_loop - 1));
        }
@@ -95,9 +94,7 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
            block_sync_lds();
            // GEMM num_loop
        }
    }
};
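The load wave above implements a one-stage prefetch: tile 0 is read and written to LDS before the loop, and each main-loop iteration overlaps the global read of tile i+1 with the math wave consuming tile i, separated by LDS barriers. A minimal host-side sketch of that schedule, with plain C++ standing in for the device calls (purely illustrative, not the actual kernel):

```cpp
#include <cstdio>

// Stand-ins for the device-side operations; each just logs what the load wave would do.
void global_read(int i) { std::printf("global read tile %d\n", i); }
void lds_write(int i) { std::printf("LDS write tile %d\n", i); }
void barrier() { std::printf("block_sync_lds()\n"); }

// Schedule of GridwiseGemmLoadWave<..., 1> as shown in the diff (1-stage prefetch).
void load_wave_schedule(int num_loop)
{
    global_read(0); // prefetch tile 0 (the source window then moves to tile 1)
    lds_write(0);

    if(num_loop > 1) // HasMainLoop
    {
        int i = 0;
        do
        {
            barrier();          // barrier paired with the math wave ("sync for Load threads()")
            global_read(i + 1); // fetch the next tile while tile i is being consumed
            barrier();          // make sure the math wave is done with tile i before overwriting LDS
            lds_write(i + 1);
            ++i;
        } while(i < (num_loop - 1));
    }

    barrier(); // hand the last tile to the math wave
}

int main() { load_wave_schedule(3); }
```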
@@ -108,10 +105,7 @@ template <typename TileMathThreadGroup>
struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
{
    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
@@ -155,7 +149,6 @@ struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
            // GEMM num_loop - 1
            block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }
    }
};
...
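On the math side, `CalculateHasMainLoop` decides whether the main loop runs at all; the pipeline peels the last tile, so the loop covers num_loop - 1 tiles and the trailing `block_gemm.Run` handles the final one. A small illustration of the loop-count arithmetic, with hypothetical problem sizes and assuming the usual CK convention that `CalculateHasMainLoop` is `num_loop > 1`:

```cpp
// Hypothetical sizes, only to illustrate the loop-count arithmetic.
constexpr int K         = 4096; // reduction dimension
constexpr int KPerBlock = 32;   // K-tile consumed per pipeline iteration

static_assert(K % KPerBlock == 0, "the pipeline assumes K splits evenly into K-tiles");

constexpr int  num_loop      = K / KPerBlock; // 128 tiles streamed through LDS
constexpr bool has_main_loop = num_loop > 1;  // CalculateHasMainLoop (assumed)

// The do-while covers num_loop - 1 tiles; the final tile is handled by the trailing
// block_sync_lds() + block_gemm.Run() shown in the hunks above.
static_assert(has_main_loop && num_loop - 1 == 127);
```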
@@ -128,42 +128,35 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
    static constexpr auto AK1 = Number<AK1Value>{};
    static constexpr auto BK1 = Number<BK1Value>{};

    struct TileLoadThreadGroup
    {
        __device__ static constexpr index_t GetNumOfThread() { return TileLoadThreadGroupSize; }

        __device__ static constexpr bool IsBelong()
        {
            return (get_thread_local_1d_id() >= TileLoadThreadGroupSize);
        }

        __device__ static index_t GetThreadId()
        {
            return get_thread_local_1d_id() - TileMathThreadGroupSize;
        }
    };

    struct TileMathThreadGroup
    {
        __device__ static constexpr index_t GetNumOfThread() { return TileMathThreadGroupSize; }

        __device__ static constexpr bool IsBelong()
        {
            return get_thread_local_1d_id() < TileMathThreadGroupSize;
        }

        __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
    };

    using CShuffleBlockTransferThreadGroup = ThisThreadBlock<TileMathThreadGroupSize>;

    // load and math+store Wave pipelines.
    // TODO: build pipelines blocks scheduling parallel tasks
    using GridwiseGemmLoad = GridwiseGemmLoadWave<TileLoadThreadGroup, NumGemmKPrefetchStage>;
    using GridwiseGemmMath = GridwiseGemmMathWave<TileMathThreadGroup, NumGemmKPrefetchStage>;
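The two thread-group structs partition a single workgroup between a math wave and a load wave: math threads occupy IDs [0, TileMathThreadGroupSize) and load threads the IDs above them. Note that the load group's `IsBelong()` compares against TileLoadThreadGroupSize, which only matches its `GetThreadId()` offset when the two group sizes are equal. A minimal sketch of that partition with hypothetical, equal group sizes:

```cpp
#include <cstdio>

// Hypothetical group sizes; the kernel's TileMathThreadGroupSize / TileLoadThreadGroupSize
// are template parameters and may differ.
constexpr int math_group_size = 128;
constexpr int load_group_size = 128;
constexpr int block_size      = math_group_size + load_group_size;

bool is_math_thread(int tid) { return tid < math_group_size; }  // TileMathThreadGroup::IsBelong
bool is_load_thread(int tid) { return tid >= load_group_size; } // TileLoadThreadGroup::IsBelong (sizes assumed equal)
int  load_local_id(int tid) { return tid - math_group_size; }   // TileLoadThreadGroup::GetThreadId

int main()
{
    for(int tid : {0, math_group_size - 1, math_group_size, block_size - 1})
    {
        std::printf("tid %3d -> %s, local id %d\n",
                    tid,
                    is_math_thread(tid) ? "math wave" : "load wave",
                    is_math_thread(tid) ? tid : load_local_id(tid));
    }
}
```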
    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
@@ -177,8 +170,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
    {
        // B matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(BK0, Number<NPerBlock>{}, BK1),
            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
    }
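The B-block LDS descriptor above pads each NPerBlock row by BBlockLdsExtraN extra K1-vectors, a common trick to stagger rows across LDS banks. A small sketch of the resulting strides with hypothetical tile sizes (the real values are template arguments of the gridwise GEMM):

```cpp
#include <cstdio>

int main()
{
    // Hypothetical tile parameters.
    constexpr int NPerBlock       = 128;
    constexpr int BK1             = 8;
    constexpr int BBlockLdsExtraN = 1; // extra column of BK1 vectors per row

    // Strides of the naive (BK0, NPerBlock, BK1) descriptor built above.
    constexpr int stride_bk0 = (NPerBlock + BBlockLdsExtraN) * BK1; // 1032 instead of 1024
    constexpr int stride_n   = BK1;
    constexpr int stride_bk1 = 1;

    std::printf("strides = {%d, %d, %d}\n", stride_bk0, stride_n, stride_bk1);
    // The +BBlockLdsExtraN padding shifts successive BK0 slices relative to the LDS bank
    // layout, which typically reduces bank conflicts for the blockwise B copy.
}
```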
    __host__ __device__ static constexpr auto
@@ -360,7 +353,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1, BK1);
@@ -392,10 +384,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        if(TileLoadThreadGroup::IsBelong())
        {
            // LoadWave
            const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
            const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -463,7 +455,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

            GridwiseGemmLoad::template RunLoadWavePipeline<HasMainKBlockLoop>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                a_blockwise_copy,
                a_grid_buf,
@@ -477,20 +470,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
                b_block_slice_copy_step,
                num_k_block_main_loop);

            block_sync_lds();
            block_sync_lds();
        }
        else if(TileMathThreadGroup::IsBelong())
        {
            // branch early for math wave
            constexpr index_t KPack =
                math::max(math::lcm(AK1, BK1),
                          MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
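KPack here is the K-granularity of each xdlops call: at least the lcm of the A/B LDS vector widths and at least the k_per_blk of the selected MFMA instruction. A toy evaluation of the same expression with hypothetical values:

```cpp
#include <algorithm>
#include <numeric>

// Hypothetical inputs; AK1/BK1 and the selected MFMA's k_per_blk depend on the instance.
constexpr int AK1       = 8;
constexpr int BK1       = 8;
constexpr int k_per_blk = 4;

// Same expression as the kernel: KPack = max(lcm(AK1, BK1), k_per_blk).
constexpr int KPack = std::max(std::lcm(AK1, BK1), k_per_blk);

static_assert(KPack == 8, "with these hypothetical values the LDS vector width dominates");
```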
            auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
                TileMathThreadGroupSize,
                FloatAB,
                FloatGemmAcc,
                decltype(a_block_desc_ak0_m_ak1),
@@ -506,11 +498,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
                p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            // TODO re-architect LDS+math stages
            GridwiseGemmMath::template RunMathWavePipeline<HasMainKBlockLoop>(
                a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_k_block_main_loop);

            // GEMM definition
            // c_mtx += transpose(a_mtx) * b_mtx
@@ -570,8 +559,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
                            N1,      // N1 = NWave
                            N2))),   // N2 = NPerXdl
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<>{},
                           Sequence<0, 2, 4, 5, 6>{},
                           Sequence<>{},
                           Sequence<1, 3, 7>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
@@ -602,8 +593,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
                make_multi_index(n_thread_data_on_block));

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                FloatGemmAcc,
                FloatCShuffle,
                decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -621,8 +612,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
                1,
                InMemoryDataOperationEnum::Set,
                1,
                true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                      make_multi_index(0,
                                       0,
                                       m_thread_data_on_block_idx[I1],
@@ -683,21 +673,24 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

            // TODO
            // 1. we do not need to do an LDS swizzle to align global writes to cache lines
            //    v_mfma cmat, amat, bmat, cmat - c-mat register layout is 1xN elements
            //    (N is the vertical/strided dimension)
            //    v_mfma cmat, bmat, amat, cmat - c-mat register layout is Mx1 elements
            //    (M is the coalescing dimension)
            //    by enumerating the M index in amat/bmat you can align cmat register(s) to
            //    contiguous M elements, for example:
            //      1st mfma instruction output space : 0 4 8 12 16 ....
            //      2nd mfma instruction output space : 1 5 9 13 17 ....
            //      3rd mfma instruction output space : 2 6 10 14 18 ....
            //      4th mfma instruction output space : 3 7 11 15 19 ....
            //    you can pack the 4 registers' output space into 2WORD and do the global write
            //    (no LDS swizzling required)
            // 2. avoid using s_barrier in this case, where not all 256 threads are required to
            //    swizzle the C layout
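The TODO above describes how the operand order of v_mfma determines the C-register layout: with the default order each instruction's accumulators land on an M-stride of 4, so four instructions interleave to cover contiguous M, and one register from each of the four instructions can be packed into a single wide global write without an LDS swizzle. A small sketch of that index mapping (illustrative only, not tied to a specific MFMA variant):

```cpp
#include <cstdio>

int main()
{
    constexpr int num_mfma = 4; // the four instructions from the comment's example
    constexpr int num_regs = 5; // registers shown per instruction: 0 4 8 12 16 ...

    // Output element covered by register r of instruction i, as in the comment:
    // instruction 0 -> 0 4 8 12 16, instruction 1 -> 1 5 9 13 17, ...
    auto out_index = [](int i, int r) { return r * num_mfma + i; };

    for(int r = 0; r < num_regs; ++r)
    {
        // Taking register r from each of the 4 instructions yields 4 contiguous elements,
        // which is what allows packing them into one wide store with no LDS swizzle.
        std::printf("register %d across instructions -> %d %d %d %d\n",
                    r, out_index(0, r), out_index(1, r), out_index(2, r), out_index(3, r));
    }
}
```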
            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
@@ -732,4 +725,4 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
        }
    }
}; // GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
} // namespace ck
@@ -249,8 +249,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
    }();

    using BlockwiseGemm =
        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, FloatAB,
                                                            FloatAcc,
                                                            decltype(a_k0_m_k1_block_desc),
                                                            decltype(b_k0_n_k1_block_desc),
...