Commit 702c3379 authored by root

fixed clang format errors

parent 599497b0
@@ -57,7 +57,7 @@ struct ThreadGroupTensorSliceTransfer_v6r1
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
//static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
// static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
// "wrong! ThreadGroup::GetNumOfThread() too small");
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
......
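For context on the assertion above, here is a minimal standalone illustration with made-up lengths (the real values are compile-time Sequence parameters of the transfer): the thread cluster must tile the slicing window exactly, and the thread group must supply exactly that many threads.

    constexpr int SliceLengths[2]         = {4, 128}; // whole slicing window
    constexpr int ThreadClusterLengths[2] = {2, 64};  // thread layout over the window
    constexpr int thread_slice_lengths[2] = {SliceLengths[0] / ThreadClusterLengths[0],  // = 2
                                             SliceLengths[1] / ThreadClusterLengths[1]}; // = 2
    // thread_slice_lengths * ThreadClusterLengths must reproduce SliceLengths ...
    static_assert(thread_slice_lengths[0] * ThreadClusterLengths[0] == SliceLengths[0] &&
                      thread_slice_lengths[1] * ThreadClusterLengths[1] == SliceLengths[1],
                  "wrong! threads should be mapped to cover entire slicing window");
    // ... and the thread group then needs 2 * 64 = 128 threads, which is what the
    // (currently commented-out) GetNumOfThread() check above would verify.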
@@ -10,7 +10,6 @@
#include "gridwise_gemm_xdl_waveletmodel_cshuffle.hpp"
#include "gemm_specialization.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
@@ -438,7 +437,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config= StreamConfig{})
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
......
@@ -7,8 +7,8 @@ namespace ck {
template <typename TileLoadThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmLoadWave;
//1-stage prefetch
template<typename TileLoadThreadGroup>
// 1-stage prefetch
template <typename TileLoadThreadGroup>
struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
{
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
@@ -53,21 +53,21 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
//move to 1
// move to 1
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
//LDS write 0
// LDS write 0
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
if constexpr(HasMainLoop)
{
index_t i=0;
index_t i = 0;
do
{
//sync for Load threads()
// sync for Load threads()
block_sync_lds();
// global read i + 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
@@ -81,11 +81,10 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
// sync with math threads()
block_sync_lds();
//LDS write i+1
// LDS write i+1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
++i;
} while(i < (num_loop - 1));
}
@@ -95,9 +94,7 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
block_sync_lds();
// GEMM num_loop
}
}
};
@@ -108,10 +105,7 @@ template <typename TileMathThreadGroup>
struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
{
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
{
return true;
}
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
@@ -155,7 +149,6 @@ struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
// GEMM num_loop - 1
block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
......
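Read together, GridwiseGemmLoadWave and GridwiseGemmMathWave split one thread block into a producer (load) wave and a consumer (math) wave that hand K-tiles over through LDS. The following host-side sketch is only an approximation of the 1-stage-prefetch handshake, with stub stages standing in for the blockwise copies, the LDS barrier, and the blockwise GEMM; the names and the exact barrier placement are illustrative, not the CK API.

    #include <cstdio>

    void global_read(int i)   { std::printf("load wave : global read tile %d\n", i); }
    void lds_write(int i)     { std::printf("load wave : LDS write tile %d\n", i); }
    void lds_read_gemm(int i) { std::printf("math wave : GEMM on tile %d\n", i); }
    void block_sync_lds()     { /* barrier shared by both waves */ }

    // Producer side, loosely following RunLoadWavePipeline above.
    void load_wave(int num_loop)
    {
        global_read(0); // prefetch tile 0 from global memory
        lds_write(0);   // publish tile 0 to LDS
        for(int i = 0; i < num_loop - 1; ++i)
        {
            block_sync_lds();   // let the math wave start on tile i
            global_read(i + 1); // fetch tile i+1 while tile i is being consumed
            block_sync_lds();   // wait until the math wave is done reading LDS
            lds_write(i + 1);   // publish tile i+1
        }
        block_sync_lds(); // hand the final tile over to the math wave
    }

    // Consumer side, loosely following RunMathWavePipeline above.
    void math_wave(int num_loop)
    {
        for(int i = 0; i < num_loop; ++i)
        {
            block_sync_lds(); // wait until tile i is visible in LDS
            lds_read_gemm(i); // accumulate tile i into the C thread buffer
            block_sync_lds(); // tell the load wave that LDS may be overwritten
        }
    }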
@@ -128,42 +128,35 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
struct TileLoadThreadGroup
{
__device__ static constexpr index_t GetNumOfThread()
{
return TileLoadThreadGroupSize;
}
__device__ static constexpr index_t GetNumOfThread() { return TileLoadThreadGroupSize; }
__device__ static constexpr bool IsBelong()
{
return (get_thread_local_1d_id() >= TileLoadThreadGroupSize);
}
__device__ static index_t GetThreadId() { return get_thread_local_1d_id() - TileMathThreadGroupSize; }
__device__ static index_t GetThreadId()
{
return get_thread_local_1d_id() - TileMathThreadGroupSize;
}
};
struct TileMathThreadGroup
{
__device__ static constexpr index_t GetNumOfThread()
{
return TileMathThreadGroupSize;
}
__device__ static constexpr index_t GetNumOfThread() { return TileMathThreadGroupSize; }
__device__ static constexpr bool IsBelong()
{
return get_thread_local_1d_id() < TileMathThreadGroupSize;
}
__device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
};
using CShuffleBlockTransferThreadGroup =
ThisThreadBlock<TileMathThreadGroupSize>;
//load and math+store Wave pipelines.
//TODO: build pipelines blocks scheduling parallel tasks
using GridwiseGemmLoad = GridwiseGemmLoadWave<TileLoadThreadGroup,NumGemmKPrefetchStage>;
using GridwiseGemmMath = GridwiseGemmMathWave<TileMathThreadGroup,NumGemmKPrefetchStage>;
using CShuffleBlockTransferThreadGroup = ThisThreadBlock<TileMathThreadGroupSize>;
// load and math+store Wave pipelines.
// TODO: build pipelines blocks scheduling parallel tasks
using GridwiseGemmLoad = GridwiseGemmLoadWave<TileLoadThreadGroup, NumGemmKPrefetchStage>;
using GridwiseGemmMath = GridwiseGemmMathWave<TileMathThreadGroup, NumGemmKPrefetchStage>;
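A hedged illustration of how these two groups partition one thread block (the 256/256 split is an assumption for the example, not a value taken from this file): the math wave owns the low thread ids and the load wave owns the high ones, renumbered from zero by GetThreadId().

    constexpr int TileMathThreadGroupSize = 256; // math + C-shuffle wave (assumed)
    constexpr int TileLoadThreadGroupSize = 256; // A/B tile load wave (assumed)

    // For a 512-thread block: ids [0, 256) compute, ids [256, 512) load.
    constexpr bool is_math_thread(int tid) { return tid < TileMathThreadGroupSize; }
    constexpr bool is_load_thread(int tid) { return tid >= TileMathThreadGroupSize; }
    constexpr int load_wave_local_id(int tid) { return tid - TileMathThreadGroupSize; } // 0..255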
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
@@ -177,8 +170,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0,Number<NPerBlock>{},BK1),
make_tuple(Number<NPerBlock+BBlockLdsExtraN>{} * BK1, BK1, I1));
make_tuple(BK0, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto
@@ -360,7 +353,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
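As a quick numeric illustration with assumed vector widths (not values from this diff): if AK1 = 8 and BK1 = 4, the A and B LDS tiles only need a common alignment of lcm(8, 4) = 8 elements.

    constexpr int gcd(int a, int b) { return b == 0 ? a : gcd(b, a % b); }
    constexpr int lcm(int a, int b) { return a / gcd(a, b) * b; }
    static_assert(lcm(8, 4) == 8, "assumed AK1 = 8, BK1 = 4 -> max_lds_align = 8");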
@@ -392,10 +384,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
if (TileLoadThreadGroup::IsBelong())
if(TileLoadThreadGroup::IsBelong())
{
//LoadWave
// LoadWave
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -463,7 +455,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
GridwiseGemmLoad::template RunLoadWavePipeline<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
GridwiseGemmLoad::template RunLoadWavePipeline<HasMainKBlockLoop>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
@@ -477,20 +470,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
b_block_slice_copy_step,
num_k_block_main_loop);
block_sync_lds();
block_sync_lds();
}
else if (TileMathThreadGroup::IsBelong())
else if(TileMathThreadGroup::IsBelong())
{
//branch early for math wave
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
// branch early for math wave
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
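For a hedged example of the KPack computation (all numbers assumed, not read from this kernel): with AK1 = BK1 = 8 and an MFMA variant whose k_per_blk is 4, KPack = max(lcm(8, 8), 4) = 8.

    constexpr int lcm_ak1_bk1   = 8; // lcm(8, 8), assuming AK1 = BK1 = 8
    constexpr int k_per_blk     = 4; // assumed selected_mfma.k_per_blk
    constexpr int kpack_example = lcm_ak1_bk1 > k_per_blk ? lcm_ak1_bk1 : k_per_blk; // = 8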
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<TileMathThreadGroupSize,
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
TileMathThreadGroupSize,
FloatAB,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
@@ -506,11 +498,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// TODO re-architect LDS+math stages
GridwiseGemmMath::template RunMathWavePipeline<HasMainKBlockLoop>(a_block_buf,
b_block_buf,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
GridwiseGemmMath::template RunMathWavePipeline<HasMainKBlockLoop>(
a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_k_block_main_loop);
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
@@ -570,8 +559,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
make_tuple(Sequence<>{},
Sequence<0, 2, 4, 5, 6>{},
Sequence<>{},
Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
@@ -602,8 +593,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
@@ -621,8 +612,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
@@ -683,21 +673,24 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
//TODO
// 1. we do not need to do LDS swizzle to align global writes writing cache lines
// v_mfma cmat, amat, bmat, cmat - c-mat register layout are 1xN elements (N is vertical or strided dimension)
// v_mfma cmat, bmat, amat, cmat - c-mat register layout are Mx1 elements (M is coalescing dimension)
// by enumerating M index in amat, bmat you can align cmat register(s) to contiguous M elements
// for example
// TODO
// 1. we do not need to do LDS swizzle to align global writes writing cache
// lines
// v_mfma cmat, amat, bmat, cmat - c-mat register layout are 1xN
// elements (N is vertical or strided dimension) v_mfma cmat, bmat, amat,
// cmat - c-mat register layout are Mx1 elements (M is coalescing
// dimension) by enumerating M index in amat, bmat you can align cmat
// register(s) to contiguous M elements for example
// 1st mfma instruction output space : 0 4 8 12 16 ....
// 2nd mfma instruction output space : 1 5 9 13 17 ....
// 3rd mfma instruction output space : 2 6 10 14 18 ....
// 4th mfma instruction output space : 3 7 11 15 19 ....
// you can pack 4 registers output space into 2WORD and do global write (no LDS swizzling required)
// 2. avoid using s_barrier in this case where not all 256 threads required to swizzle c layout
// you can pack 4 registers output space into 2WORD and do global write
// (no LDS swizzling required)
// 2. avoid using s_barrier in this case where not all 256 threads required to
// swizzle c layout
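// Illustration of the layout described above (hypothetical helper, not CK code):
// register r of mfma instruction j holds output element m = 4 * r + j, i.e.
//   int output_m_index(int j, int r) { return 4 * r + j; } // j = 0..3 -> 0 4 8 12 / 1 5 9 13 / ...
// so register r of instructions j = 0..3 covers the contiguous M range [4*r, 4*r + 3]
// and those four values can be packed and issued as a single global write.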
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
@@ -732,4 +725,4 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
}
}
}; // GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
} //namespace ck
} // namespace ck
@@ -249,8 +249,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
}();
using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
......