Commit 7d0a5412 authored by root

threadwise transfer

parent b3a012bc
@@ -9,27 +9,27 @@ namespace ck {
// blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N] // blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N]
// A and B are visible to the whole block, C is distributed among the threads // A and B are visible to the whole block, C is distributed among the threads
// If the following numbers are powers of 2, index calculation cost is greatly reduced: // If the following numbers are powers of 2, index calculation cost is greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster, // KPerThread, HPerThread, MLevel0ThreadCluster, NLevel0ThreadCluster,
// MLevel1ThreadCluster, NLevel1ThreadCluster // MLevel1ThreadCluster, NLevel1ThreadCluster
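As an aside (not part of the commit): power-of-2 sizes matter because the thread-id decomposition uses integer division and modulo, which the compiler lowers to shifts and masks when the divisor is a compile-time power of two. A minimal host-side illustration, using the 8x8 H/W thread cluster this commit hard-codes:

```cpp
// Illustrative sketch only: decomposing a linear thread id into (h, w) within an
// 8x8 thread cluster. With power-of-two cluster sizes, / and % become shift/mask.
#include <cstdio>

int main()
{
    constexpr int HThreadCluster = 8; // values hard-coded later in this commit
    constexpr int WThreadCluster = 8;

    for(int tid = 0; tid < HThreadCluster * WThreadCluster; tid += 21)
    {
        const int h = tid / WThreadCluster; // == tid >> 3
        const int w = tid % WThreadCluster; // == tid & 7
        std::printf("tid %2d -> (h=%d, w=%d)\n", tid, h, w);
    }
    return 0;
}
```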
template <index_t BlockSize, template <index_t BlockSize,
typename BlockMatrixA, typename BlockMatrixA,
typename BlockMatrixB, typename BlockMatrixB,
typename ThreadMatrixC, typename ThreadMatrixC,
index_t MPerThreadSubC, index_t KPerThread,
index_t NPerThreadSubC, index_t HPerThread,
index_t KPerThreadLoop, index_t WPerThread,
index_t MLevel0ThreadCluster, index_t CYXPerThreadLoop,
index_t NLevel0ThreadCluster, index_t HThreadCluster,
index_t MLevel1ThreadCluster, index_t WThreadCluster,
index_t NLevel1ThreadCluster, index_t ThreadGemmADataPerRead_K,
index_t ThreadGemmADataPerRead_M, index_t ThreadGemmBDataPerRead_W>
index_t ThreadGemmBDataPerRead_N>
struct BlockwiseGemm_km_kn_m0m1n0n1_v3 struct BlockwiseGemm_km_kn_m0m1n0n1_v3
{ {
struct MatrixIndex struct MatrixIndex
{ {
index_t row; index_t k;
index_t col; index_t h;
index_t w;
}; };
index_t mMyThreadOffsetA; index_t mMyThreadOffsetA;
@@ -44,325 +44,153 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster * // constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
MLevel1ThreadCluster * NLevel1ThreadCluster; // MLevel1ThreadCluster * NLevel1ThreadCluster;
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n"); static_assert(BlockSize == HThreadCluster * WThreadCluster, "wrong! wrong blocksize\n");
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0), static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
"wrong! K dimension not consistent\n"); "wrong! K dimension not consistent\n");
constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t N = BlockMatrixB{}.GetLength(I1); constexpr index_t N = BlockMatrixB{}.GetLength(I1);
constexpr index_t H = BlockMatrixB{}.GetLength(I2);
constexpr index_t W = BlockMatrixB{}.GetLength(I3);
static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 && static_assert(
N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0, K % (KPerThread) == 0 &&
"wrong! Cannot evenly divide work among\n"); (N * H * W) % (HPerThread * WPerThread * HThreadCluster * WThreadCluster) == 0,
"wrong! Cannot evenly divide work among\n");
static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] &&
ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1],
"wrong! ThreadMatrixC lengths is wrong");
auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id()); auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.row)); mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.k));
mMyThreadOffsetB = BlockMatrixB{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.col)); mMyThreadOffsetB = BlockMatrixB{}.CalculateOffset(
make_tuple(0, 0, c_thread_mtx_index.h, c_thread_mtx_index.w));
} }
__device__ static constexpr auto GetThreadMatrixCLengths() __device__ static constexpr auto GetThreadMatrixCLengths()
{ {
constexpr auto I1 = Number<1>{}; return Sequence<KPerThread, 1, HPerThread, WPerThread>{};
constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
constexpr index_t MRepeat =
M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
constexpr index_t NRepeat =
N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);
return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
} }
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{ {
constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster; return MatrixIndex{1, 8, 8};
index_t level1_id = thread_id / ThreadPerLevel0Cluster;
index_t level1_m_id = level1_id / NLevel1ThreadCluster;
index_t level1_n_id = level1_id % NLevel1ThreadCluster;
index_t level0_id = thread_id % ThreadPerLevel0Cluster;
index_t level0_m_id = level0_id / NLevel0ThreadCluster;
index_t level0_n_id = level0_id % NLevel0ThreadCluster;
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
} }
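The new GetBeginOfThreadMatrixC above currently returns a constant index (work in progress). A hypothetical sketch of how the per-thread (k, h, w) origin could be derived from the linear thread id, mirroring the h_thread_id / w_thread_id split the gridwise kernel below uses; the names and the k = 0 choice are assumptions, not part of the commit:

```cpp
// Hypothetical, host-side sketch; not the commit's implementation.
#include <cstdio>

struct MatrixIndex { int k; int h; int w; };

MatrixIndex begin_of_thread_matrix_c(int thread_id, int WThreadCluster,
                                     int HPerThread, int WPerThread)
{
    const int h_thread_id = thread_id / WThreadCluster;
    const int w_thread_id = thread_id % WThreadCluster;
    // each thread covers the full K range of its tile, so k starts at 0
    return MatrixIndex{0, h_thread_id * HPerThread, w_thread_id * WPerThread};
}

int main()
{
    const MatrixIndex idx = begin_of_thread_matrix_c(/*thread_id*/ 10, /*WThreadCluster*/ 8,
                                                     /*HPerThread*/ 1, /*WPerThread*/ 1);
    std::printf("k=%d h=%d w=%d\n", idx.k, idx.h, idx.w); // k=0 h=1 w=2
    return 0;
}
```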
template <typename FloatA, typename FloatB, typename FloatC> template <typename SrcDesc,
__device__ void typename DstDesc,
Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const index_t NSliceRow,
index_t NSliceCol,
index_t DataPerAccess>
struct ThreadwiseSliceCopy_a
{ {
constexpr auto I0 = Number<0>{}; template <typename Data>
constexpr auto I1 = Number<1>{}; __device__ static void Run(const Data* p_src, Data* p_dst)
{
constexpr auto a_block_mtx = BlockMatrixA{}; static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
constexpr auto b_block_mtx = BlockMatrixB{}; "wrong! Desc should be known at compile-time");
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr auto K = a_block_mtx.GetLength(I0);
constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// thread A, B for GEMM
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<MPerThread>{}));
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<NPerThread>{}));
FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()]; using vector_t = typename vector_type<Data, DataPerAccess>::type;
FloatB p_b_thread[b_thread_mtx.GetElementSpaceSize()];
constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v3<BlockMatrixA, static_for<0, NSliceRow, 1>{}([&](auto i) {
decltype(a_thread_mtx), static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
KPerThreadLoop, constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j));
MPerThreadSubC, constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j));
ThreadGemmADataPerRead_M>{};
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v3<BlockMatrixB, *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
decltype(b_thread_mtx), *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
KPerThreadLoop, });
NPerThreadSubC, });
ThreadGemmBDataPerRead_N>{}; }
};
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_mtx), template <typename SrcDesc,
decltype(b_thread_mtx), typename DstDesc,
decltype(c_thread_mtx)>{}; index_t NSliceCYX,
#pragma unroll index_t NSliceH,
// loop over k index_t NSliceW,
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) index_t DataPerAccess>
struct ThreadwiseSliceCopy_b
{
template <typename Data>
__device__ static void Run(const Data* p_src, Data* p_dst)
{ {
#pragma unroll static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
// read A "wrong! Desc should be known at compile-time");
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ using vector_t = typename vector_type<Data, DataPerAccess>::type;
a_thread_copy.Run(p_a_block +
a_block_mtx.CalculateOffset( static_for<0, NSliceCYX, 1>{}([&](auto i) {
make_tuple(k_begin, m_repeat * MPerLevel1Cluster)) + static_for<0, NSliceH, 1>{}([&](auto j) {
mMyThreadOffsetA, static_for<0, NSliceW, 1>{}([&](auto k) {
p_a_thread + a_thread_mtx.CalculateOffset( constexpr auto src_offset =
make_tuple(0, m_repeat * MPerThreadSubC))); SrcDesc{}.CalculateOffset(make_tuple(i, 0, j, k));
} constexpr auto dst_offset =
DstDesc{}.CalculateOffset(make_tuple(i, 0, j, k));
#pragma unroll
// read B *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
{ });
b_thread_copy.Run(p_b_block + });
b_block_mtx.CalculateOffset( });
make_tuple(k_begin, n_repeat * NPerLevel1Cluster)) +
mMyThreadOffsetB,
p_b_thread + b_thread_mtx.CalculateOffset(
make_tuple(0, n_repeat * NPerThreadSubC)));
}
// C += A * B
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
} }
} };
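Both ThreadwiseSliceCopy_a and ThreadwiseSliceCopy_b above follow the same pattern: walk the slice with compile-time loops and move DataPerAccess contiguous scalars per step as one vector access. A standalone host-side sketch of that pattern (raw strides and memcpy stand in for the compile-time descriptors and vector_type; assumptions, not the kernel code):

```cpp
// Illustrative sketch of a vectorized 2-D slice copy; memcpy models the
// DataPerAccess-wide vector load/store done with vector_type in the real code.
#include <cstdio>
#include <cstring>

template <int NSliceRow, int NSliceCol, int DataPerAccess>
void slice_copy(const float* p_src, int src_stride, float* p_dst, int dst_stride)
{
    static_assert(NSliceCol % DataPerAccess == 0, "row length must be a multiple of the access width");
    for(int i = 0; i < NSliceRow; ++i)
        for(int j = 0; j < NSliceCol; j += DataPerAccess)
            std::memcpy(&p_dst[i * dst_stride + j], &p_src[i * src_stride + j],
                        DataPerAccess * sizeof(float));
}

int main()
{
    float src[2 * 16], dst[2 * 8] = {};
    for(int i = 0; i < 32; ++i) src[i] = float(i);
    slice_copy<2, 8, 4>(src, /*src_stride*/ 16, dst, /*dst_stride*/ 8); // copy a 2x8 tile
    std::printf("dst[8] = %g\n", dst[8]); // first element of the second row: 16
    return 0;
}
```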
template <typename FloatA, typename FloatB, typename FloatC> template <typename FloatA, typename FloatB, typename FloatC>
__device__ void __device__ void
Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const Run_naive(const FloatA* p_a_block, const FloatB* p_b_thread, FloatC* p_c_thread) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto a_block_mtx = BlockMatrixA{}; constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr auto K = a_block_mtx.GetLength(I0);
constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster = constexpr auto CYXPerBlock = a_block_mtx.GetLength(I0);
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC; // thread A, B for GEMM
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert(MRepeat == 2 && NRepeat == 2,
"wrong! inline asm cannot deal with this GEMM config yet");
// thread A, B
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2( constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<MPerThread>{})); make_tuple(Number<CYXPerThreadLoop>{}, Number<KPerThread>{}));
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2( constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<NPerThread>{})); make_tuple(Number<CYXPerThreadLoop>{}, Number<1>{}, Number<1>{}, Number<1>{}));
// thread A-sub, B-sub
constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}),
make_tuple(Number<MPerThread>{}, Number<1>{}));
constexpr auto b_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2( constexpr auto c_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}), make_tuple(Number<KPerThread>{}, Number<1>{}));
make_tuple(Number<NPerThread>{}, Number<1>{}));
constexpr auto c_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}),
make_tuple(Number<NPerThread>{}, Number<1>{}));
FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()]; FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
FloatB p_b_thread[b_thread_mtx.GetElementSpaceSize()];
constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v3<BlockMatrixA, constexpr auto a_thread_copy = ThreadwiseSliceCopy_a<BlockMatrixA,
decltype(a_thread_mtx), decltype(a_thread_mtx),
KPerThreadLoop, CYXPerThreadLoop,
MPerThreadSubC, KPerThread,
ThreadGemmADataPerRead_M>{}; ThreadGemmADataPerRead_K>{};
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v3<BlockMatrixB, constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<decltype(a_thread_mtx),
decltype(b_thread_mtx), decltype(b_thread_mtx),
KPerThreadLoop, decltype(c_thread_mtx)>{};
NPerThreadSubC, // loop over CYX
ThreadGemmBDataPerRead_N>{}; for(index_t cyx_begin = 0; cyx_begin < CYXPerBlock; cyx_begin += CYXPerThreadLoop)
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_sub_mtx),
decltype(b_thread_sub_mtx),
decltype(c_thread_sub_mtx)>{};
const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;
// read A_sub_0
a_thread_copy.Run(p_a_block_off, p_a_thread);
// read B_sub_0
b_thread_copy.Run(p_b_block_off, p_b_thread);
// read B_sub_1
b_thread_copy.Run(p_b_block_off +
b_block_mtx.CalculateOffset(make_tuple(0, NPerLevel1Cluster)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
// read A_sub_1
a_thread_copy.Run(p_a_block_off +
a_block_mtx.CalculateOffset(make_tuple(0, MPerLevel1Cluster)),
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)));
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(
p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
#pragma unroll
// loop over rest of k
for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
{ {
// read A_sub_0 a_thread_copy.Run(p_a_block + a_block_mtx.CalculateOffset(make_tuple(cyx_begin, 0)) +
a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, 0)), mMyThreadOffsetA,
p_a_thread); p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, 0)));
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread,
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
// read B_sub_0
b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, 0)),
p_b_thread);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread +
c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
// read B_sub_1
b_thread_copy.Run(
p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, NPerLevel1Cluster)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
// read A_sub_1
a_thread_copy.Run(
p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, MPerLevel1Cluster)),
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)));
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(
p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
}
// C_sub_10 += transpose(A_sub_1) * B_sub_0 // threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
threadwise_gemm.Run( }
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread,
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
} }
template <typename FloatA, typename FloatB, typename FloatC> template <typename FloatA, typename FloatB, typename FloatC>
__device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{ {
#if 0
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t MPerThread = ThreadMatrixC{}.GetLength(I0);
constexpr index_t NPerThread = ThreadMatrixC{}.GetLength(I1);
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
if constexpr(MRepeat == 2 && NRepeat == 2)
{
Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
}
else
{
Run_naive(p_a_block, p_b_block, p_c_thread);
}
#else
Run_naive(p_a_block, p_b_block, p_c_thread); Run_naive(p_a_block, p_b_block, p_c_thread);
#endif
} }
}; };
...
@@ -18,12 +18,12 @@ template <index_t BlockSize,
typename AGlobalDesc, typename AGlobalDesc,
typename BGlobalDesc, typename BGlobalDesc,
typename CGlobalDesc, typename CGlobalDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock, index_t KPerBlock,
index_t MPerThread, index_t HWPerBlock,
index_t NPerThread, index_t CYXPerBlock,
index_t KPerThread, index_t KPerThread,
index_t HWPerThread,
index_t CYXPerThread,
index_t MLevel0Cluster, index_t MLevel0Cluster,
index_t NLevel0Cluster, index_t NLevel0Cluster,
index_t MLevel1Cluster, index_t MLevel1Cluster,
@@ -58,31 +58,34 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
{ {
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{}, Number<BBlockTransferDstScalarPerVector_N>{},
Number<MPerThread>{}, Number<KPerThread>{},
Number<NPerThread>{}); Number<HWPerThread>{});
static_assert(CYXPerBlock == 4 && HWPerBlock == 64 && KPerBlock == 16, "");
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_cyx_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align); make_tuple(Number<CYXPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_cyx_n_h_w_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_cyx_n_h_w_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<1>{}, Number<8>{}, Number<8>{}), max_lds_align); make_tuple(Number<CYXPerBlock>{}, Number<1>{}, Number<8>{}, Number<8>{}),
max_lds_align);
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_cyx_k_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size = constexpr auto b_block_space_size = math::integer_least_multiple(
math::integer_least_multiple(b_cyx_n_h_w_block_desc.GetElementSpaceSize(), max_lds_align); b_cyx_n_h_w_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float); return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
} }
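Two details of the LDS sizing above are worth spelling out: each buffer is padded up to a multiple of max_lds_align (the lcm of the vector access widths), and the leading factor of 2 reserves space for the double-buffered (ping-pong) A and B tiles used in Run. A small sketch of that arithmetic, using the tile sizes this commit asserts (CYXPerBlock = 4, KPerBlock = 16, an 8x8 H/W tile); the vector widths are made-up examples:

```cpp
// Sketch of the LDS-size arithmetic: pad each buffer to the lcm of the access
// widths, then double everything for the ping-pong buffers.
#include <cstdio>
#include <numeric>

int integer_least_multiple(int n, int align) { return (n + align - 1) / align * align; }

int main()
{
    const int max_lds_align   = std::lcm(std::lcm(4, 1), 16);                  // example widths
    const int a_block_space   = integer_least_multiple(4 * 16, max_lds_align); // CYXPerBlock * KPerBlock
    const int b_block_space   = integer_least_multiple(4 * 1 * 8 * 8, max_lds_align);
    std::printf("shared memory bytes = %zu\n",
                2 * (a_block_space + b_block_space) * sizeof(float));
    return 0;
}
```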
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc, __device__ void Run(const AGlobalDesc& a_cyx_k_global_desc,
const Float* __restrict__ p_a_global, const Float* __restrict__ p_a_global,
const BGlobalDesc& b_cyx_n_h_w_global_desc, const BGlobalDesc& b_cyx_n_h_w_global_desc,
const Float* __restrict__ p_b_global, const Float* __restrict__ p_b_global,
@@ -94,62 +94,70 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto CYX = a_cyx_k_global_desc.GetLength(I0);
const auto K = a_cyx_k_global_desc.GetLength(I1);
static_assert(CYX == 4 * 3 * 3 && K == 16, "");
const auto K = a_k_m_global_desc.GetLength(I0);
const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_cyx_n_h_w_global_desc.GetLength(I1); const auto N = b_cyx_n_h_w_global_desc.GetLength(I1);
const auto H = b_cyx_n_h_w_global_desc.GetLength(I2);
const auto W = b_cyx_n_h_w_global_desc.GetLength(I3);
// divide block work by [M, N] // divide block work by [M, N]
#if 0 #if 1
const auto m_block_work_num = M / Number<MPerBlock>{}; const auto m_block_work_num = K / Number<KPerBlock>{};
const auto n_block_work_num = N / Number<NPerBlock>{}; const auto nhw_block_work_num = (N * H * W) / Number<HWPerBlock>{};
const index_t m_block_work_id = get_block_1d_id() / n_block_work_num; const index_t k_block_work_id = get_block_1d_id() / nhw_block_work_num;
const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num; const index_t nhw_block_work_id = get_block_1d_id() - k_block_work_id * nhw_block_work_num;
#else #else
// Hack: this force result into SGPR // Hack: this force result into SGPR
const index_t m_block_work_num = __builtin_amdgcn_readfirstlane(M / MPerBlock); const index_t m_block_work_num = __builtin_amdgcn_readfirstlane(K / KPerBlock);
const index_t n_block_work_num = __builtin_amdgcn_readfirstlane(N / NPerBlock); const index_t nhw_block_work_num = __builtin_amdgcn_readfirstlane(N / HWPerBlock);
const index_t m_block_work_id = const index_t k_block_work_id =
__builtin_amdgcn_readfirstlane(get_block_1d_id() / n_block_work_num); __builtin_amdgcn_readfirstlane(get_block_1d_id() / nhw_block_work_num);
const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num; const index_t nhw_block_work_id = get_block_1d_id() - k_block_work_id * nhw_block_work_num;
#endif #endif
const index_t m_block_data_on_global = m_block_work_id * MPerBlock; const index_t m_block_data_on_global = k_block_work_id * KPerBlock;
const index_t h_block_data_on_global = n_block_work_id * 8; const index_t h_block_data_on_global = nhw_block_work_id * 8;
const index_t w_block_data_on_global = n_block_work_id * 8; const index_t w_block_data_on_global = nhw_block_work_id * 8;
// lds max alignment // lds max alignment
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{}, Number<BBlockTransferDstScalarPerVector_N>{},
Number<MPerThread>{}, Number<KPerThread>{},
Number<NPerThread>{}); Number<HWPerThread>{});
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto a_cyx_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align); make_tuple(Number<CYXPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
constexpr auto b_cyx_n_h_w_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( constexpr auto b_cyx_n_h_w_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<1>{}, Number<8>{}, Number<8>{}), max_lds_align); make_tuple(Number<CYXPerBlock>{}, Number<1>{}, Number<8>{}, Number<8>{}),
max_lds_align);
// A matrix blockwise copy // A matrix blockwise copy
auto a_blockwise_copy = auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize, BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
Sequence<KPerBlock, MPerBlock>, Sequence<CYXPerBlock, KPerBlock>,
ABlockTransferThreadSliceLengths_K_M, ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M, ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
Float, Float,
Float, Float,
decltype(a_k_m_global_desc), decltype(a_cyx_k_global_desc),
decltype(a_k_m_block_desc), decltype(a_cyx_k_block_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
Sequence<0, 1>, Sequence<0, 1>,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
@@ -162,101 +173,65 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
1, 1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
true>( true>(
a_k_m_global_desc, a_cyx_k_global_desc,
make_multi_index(0, m_block_data_on_global), make_multi_index(0, m_block_data_on_global),
a_k_m_block_desc, a_cyx_k_block_desc,
make_multi_index(0, 0)); make_multi_index(0, 0));
// B matrix blockwise copy #if 1
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4< constexpr auto b_cyx_n_h_w_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
BlockSize, make_tuple(Number<CYXPerThread>{}, Number<1>{}, Number<1>{}, Number<1>{}));
InMemoryDataOperation::Set,
Sequence<KPerBlock, 1, 8, 8>, // BlockSliceLengths const index_t h_thread_id = get_thread_local_1d_id() / 8;
Sequence<KPerBlock, 1, 1, 1>, // ThreadSliceLengths_K_N const index_t w_thread_id = get_thread_local_1d_id() % 8;
Sequence<1, 1, 8, 8>, // ThreadClusterLengths_K_N
Sequence<3, 2, 0, 1>, // ThreadClusterArrangeOrder auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2<
Float, Float,
Float, Float,
decltype(b_cyx_n_h_w_global_desc), // SrcDesc decltype(b_cyx_n_h_w_global_desc),
decltype(b_cyx_n_h_w_block_desc), // DstDesc decltype(b_cyx_n_h_w_thread_desc),
Sequence<3, 2, 0, 1>, // SrcDimAccessOrder Sequence<CYXPerThread, 1, 1, 1>,
Sequence<3, 2, 0, 1>, // DstDimAccessOrder Sequence<3, 2, 0, 1>, // BBlockTransferSrcAccessOrder,
3, // SrcVectorDim 3, // BBlockTransferSrcVectorDim,
3, // DstVectorDim 1, // BBlockTransferSrcScalarPerVector,
1, // SrcScalarPerVector
1, // DstScalarPerVector
AddressSpace::Global, AddressSpace::Global,
AddressSpace::Lds, AddressSpace::Vgpr,
1, InMemoryDataOperation::Set,
1, 1,
BThreadTransferSrcResetCoordinateAfterRun, true>(
true>(b_cyx_n_h_w_global_desc, b_cyx_n_h_w_global_desc,
make_multi_index(0, 0, h_block_data_on_global, w_block_data_on_global), make_multi_index(
b_cyx_n_h_w_block_desc, 0, 0, h_block_data_on_global + h_thread_id, w_block_data_on_global + w_thread_id));
make_multi_index(0, 0, 0, 0));
#if 0
constexpr auto b_cyx_n_h_w_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThread>{}, Number<NPerThread>{}));
using BThreadwiseTransfer =
ThreadwiseDynamicTensorSliceTransfer_v2<Float,
Float,
decltype(b_cyx_n_h_w_global_desc),
decltype(b_cyx_n_h_w_thread_desc),
Sequence<KPerThread, NPerThread>,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
AddressSpace::Global,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
1,
true>;
#endif #endif
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
//static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
//NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
//"wrong!");
// constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
// constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
// c_thread_mtx definition: this is a mess // c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx // TODO: more elegant way of defining c_thread_mtx
constexpr auto c_k_n_h_w_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2( constexpr auto c_k_n_h_w_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MPerThread>{}, Number<1>{}, Number<1>{}, Number<1>{})); make_tuple(Number<KPerThread>{}, Number<1>{}, Number<1>{}, Number<1>{}));
#if 0 #if 1
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize, BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
decltype(a_k_m_block_desc), decltype(a_cyx_k_block_desc),
decltype(b_cyx_n_h_w_block_desc), decltype(b_cyx_n_h_w_block_desc),
decltype(c_k_n_h_w_thread_desc), decltype(c_k_n_h_w_thread_desc),
MPerThread, 16, // KPerThreadSubC
NPerThread, 1, // HPerThreadSubC
KPerThread, 1, // WPerThreadSubC
MLevel0Cluster, 1, // CYXPerThreadLoop
NLevel0Cluster, 8, // HThreadCluster
MLevel1Cluster, 8, // WThreadCluster
NLevel1Cluster, 1, // ThreadGemmADataPerRead_K
1, 1 // ThreadGemmBDataPerRead_W
1>{}; >{};
#endif #endif
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_cyx_k_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size = constexpr auto b_block_space_size = math::integer_least_multiple(
math::integer_least_multiple(b_cyx_n_h_w_block_desc.GetElementSpaceSize(), max_lds_align); b_cyx_n_h_w_block_desc.GetElementSpaceSize(), max_lds_align);
Float* p_a_block_double = p_shared_block; Float* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space_size; Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
@@ -272,11 +247,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// zero out threadwise output // zero out threadwise output
// threadwise_matrix_set_zero_v2(c_k_n_h_w_thread_desc, p_c_thread); // threadwise_matrix_set_zero_v2(c_k_n_h_w_thread_desc, p_c_thread);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0); constexpr auto a_block_slice_copy_step = make_multi_index(CYXPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(CYXPerBlock, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy // hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{}; constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_cyx_n_h_w_global_iterator_hacks = BGlobalIteratorHacks{}; constexpr auto b_cyx_n_h_w_global_iterator_hacks = BGlobalIteratorHacks{};
// hack to control index calculation when move slice window for A and B matrix for // hack to control index calculation when move slice window for A and B matrix for
@@ -288,13 +263,25 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks); a_blockwise_copy.RunRead(a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_cyx_n_h_w_global_desc, p_b_global, b_cyx_n_h_w_global_iterator_hacks);
constexpr auto b_thread_mtx = b_cyx_n_h_w_thread_desc;
Float p_b_thread[b_thread_mtx.GetElementSpaceSize()];
b_threadwise_transfer.Run(b_cyx_n_h_w_global_desc,
p_b_global,
b_cyx_n_h_w_thread_desc,
make_tuple(I0, I0, I0, I0),
p_b_thread,
b_cyx_n_h_w_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double); a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_double);
b_blockwise_copy.RunWrite(b_cyx_n_h_w_block_desc, p_b_block_double);
__syncthreads();
} }
#if 0
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
{ {
Float* p_a_block_even = p_a_block_double; Float* p_a_block_even = p_a_block_double;
@@ -303,104 +290,82 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
Float* p_a_block_odd = p_a_block_double + a_block_space_size; Float* p_a_block_odd = p_a_block_double + a_block_space_size;
Float* p_b_block_odd = p_b_block_double + b_block_space_size; Float* p_b_block_odd = p_b_block_double + b_block_space_size;
index_t k_block_data_begin = 0; index_t b_block_data_begin = 0;
// LDS double buffer: main body // LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow // use Do-While loop instead of For loop to simplify control flow
do do
{ {
// even iteration // even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_blockwise_copy.MoveSrcSliceWindow(a_cyx_k_global_desc,
a_block_slice_copy_step, a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack); a_k_m_global_move_slice_window_iterator_hack);
// b_blockwise_copy.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
// b_block_slice_copy_step,
// b_cyx_n_h_w_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
b_block_slice_copy_step,
b_cyx_n_h_w_global_move_slice_window_iterator_hack);
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks); a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_cyx_n_h_w_global_desc, p_b_global, b_cyx_n_h_w_global_iterator_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
// blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread); blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd); a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_odd);
b_blockwise_copy.RunWrite(b_cyx_n_h_w_block_desc, p_b_block_odd);
// odd iteration // odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_blockwise_copy.MoveSrcSliceWindow(a_cyx_k_global_desc,
a_block_slice_copy_step, a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack); a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
b_block_slice_copy_step,
b_cyx_n_h_w_global_move_slice_window_iterator_hack);
__syncthreads(); __syncthreads();
// LDS double buffer: load next data from device mem // LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(
a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks); a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_cyx_n_h_w_global_desc, p_b_global, b_cyx_n_h_w_global_iterator_hacks);
// LDS double buffer: GEMM on current data // LDS double buffer: GEMM on current data
// blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread); blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
// LDS double buffer: store next data to LDS // LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even); a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_even);
b_blockwise_copy.RunWrite(b_cyx_n_h_w_block_desc, p_b_block_even);
k_block_data_begin += 2 * KPerBlock; b_block_data_begin += 2 * CYXPerBlock;
} while(k_block_data_begin < K - 2 * KPerBlock); } while(b_block_data_begin < CYX - 2 * CYXPerBlock);
} }
// LDS double buffer: tail // LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{ {
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_blockwise_copy.MoveSrcSliceWindow(a_cyx_k_global_desc,
a_block_slice_copy_step, a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack); a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_cyx_n_h_w_global_desc,
b_block_slice_copy_step,
b_cyx_n_h_w_global_move_slice_window_iterator_hack);
__syncthreads(); __syncthreads();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks); a_blockwise_copy.RunRead(a_cyx_k_global_desc, p_a_global, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_cyx_n_h_w_global_desc, p_b_global, b_cyx_n_h_w_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
// blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS // LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size); a_blockwise_copy.RunWrite(a_cyx_k_block_desc, p_a_block_double + a_block_space_size);
b_blockwise_copy.RunWrite(b_cyx_n_h_w_block_desc, p_b_block_double + b_block_space_size);
__syncthreads(); __syncthreads();
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
// blockwise_gemm.Run(p_a_block_double + a_block_space_size, blockwise_gemm.Run(p_a_block_double + a_block_space_size,
// p_b_block_double + b_block_space_size, p_b_block_double + b_block_space_size,
// p_c_thread); p_c_thread);
} }
else // if has 1 iteration left else // if has 1 iteration left
{ {
__syncthreads(); __syncthreads();
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
// blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
} }
#endif
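The main loop disabled by the #if 0 above is an LDS double-buffering (ping-pong) scheme: while the blockwise GEMM consumes one pair of A/B tiles in LDS, the blockwise copy fetches the next pair into the other half of the allocation. A minimal host-side sketch of that control flow (load and compute are hypothetical stand-ins):

```cpp
// Illustrative ping-pong control flow; load()/compute() are hypothetical stand-ins
// for the blockwise copy (global -> LDS) and the blockwise GEMM on the LDS tiles.
#include <cstdio>
#include <vector>

int main()
{
    constexpr int num_tiles = 6; // e.g. CYX / CYXPerBlock iterations
    std::vector<float> buf[2]{std::vector<float>(16), std::vector<float>(16)};

    auto load    = [](std::vector<float>& b, int tile) { b[0] = float(tile); };
    auto compute = [](const std::vector<float>& b) { std::printf("GEMM on tile %g\n", b[0]); };

    load(buf[0], 0); // preload the first tile
    for(int tile = 1; tile < num_tiles; ++tile)
    {
        load(buf[tile % 2], tile);    // fill the idle buffer with the next tile
        compute(buf[(tile - 1) % 2]); // consume the buffer filled last iteration
    }
    compute(buf[(num_tiles - 1) % 2]); // tail iteration
    return 0;
}
```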
#if 1 #if 1
// output: register to global memory // output: register to global memory
@@ -408,7 +373,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// define input tensor descriptor for threadwise copy // define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy // thread input tensor, src of threadwise copy
constexpr auto c_k_n_h_w_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2( constexpr auto c_k_n_h_w_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MPerThread>{}, Number<1>{}, Number<1>{}, Number<1>{})); make_tuple(Number<KPerThread>{}, Number<1>{}, Number<1>{}, Number<1>{}));
// calculate origin of thread input tensor on global memory // calculate origin of thread input tensor on global memory
// blockwise GEMM c matrix starting index // blockwise GEMM c matrix starting index
@@ -432,15 +397,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// hack to control index calculation when iterating over c_k_n_h_w_global tensor // hack to control index calculation when iterating over c_k_n_h_w_global tensor
constexpr auto c_k_n_h_w_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; constexpr auto c_k_n_h_w_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
// constexpr auto tmp = make_unmerge_transform(make_tuple(
// Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseDynamicTensorSliceTransfer_v1r3<
AccFloat, AccFloat,
Float, Float,
decltype(c_k_n_h_w_thread_desc), decltype(c_k_n_h_w_thread_desc),
decltype(c_k_n_h_w_global_desc), decltype(c_k_n_h_w_global_desc),
Sequence<MPerThread, 1, 1, 1>, Sequence<KPerThread, 1, 1, 1>,
Sequence<3, 2, 0, 1>, // CThreadTransferSrcDstAccessOrder Sequence<3, 2, 0, 1>, // CThreadTransferSrcDstAccessOrder
3, // CThreadTransferSrcDstVectorDim 3, // CThreadTransferSrcDstVectorDim
1, // CThreadTransferDstScalarPerVector, 1, // CThreadTransferDstScalarPerVector,
@@ -464,7 +426,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// pass tensor descriptor by reference // pass tensor descriptor by reference
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc, __device__ void Run(const AGlobalDesc& a_cyx_k_global_desc,
const Float* __restrict__ p_a_global, const Float* __restrict__ p_a_global,
const BGlobalDesc& b_cyx_n_h_w_global_desc, const BGlobalDesc& b_cyx_n_h_w_global_desc,
const Float* __restrict__ p_b_global, const Float* __restrict__ p_b_global,
@@ -477,7 +439,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
__shared__ Float p_shared_block[shared_block_size]; __shared__ Float p_shared_block[shared_block_size];
Run(a_k_m_global_desc, Run(a_cyx_k_global_desc,
p_a_global, p_a_global,
b_cyx_n_h_w_global_desc, b_cyx_n_h_w_global_desc,
p_b_global, p_b_global,
@@ -490,7 +452,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// pass tensor descriptors by their pointers // pass tensor descriptors by their pointers
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc* p_a_k_m_global_desc, __device__ void Run(const AGlobalDesc* p_a_cyx_k_global_desc,
const Float* __restrict__ p_a_global, const Float* __restrict__ p_a_global,
const BGlobalDesc* p_b_cyx_n_h_w_global_desc, const BGlobalDesc* p_b_cyx_n_h_w_global_desc,
const Float* __restrict__ p_b_global, const Float* __restrict__ p_b_global,
@@ -499,11 +461,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
const auto a_k_m_global_desc = *p_a_k_m_global_desc; const auto a_cyx_k_global_desc = *p_a_cyx_k_global_desc;
const auto b_cyx_n_h_w_global_desc = *p_b_cyx_n_h_w_global_desc; const auto b_cyx_n_h_w_global_desc = *p_b_cyx_n_h_w_global_desc;
const auto c_k_n_h_w_global_desc = *p_c_k_n_h_w_global_desc; const auto c_k_n_h_w_global_desc = *p_c_k_n_h_w_global_desc;
Run(a_k_m_global_desc, Run(a_cyx_k_global_desc,
p_a_global, p_a_global,
b_cyx_n_h_w_global_desc, b_cyx_n_h_w_global_desc,
p_b_global, p_b_global,
@@ -515,7 +477,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// pass tensor descriptors by void* // pass tensor descriptors by void*
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const void* p_a_k_m_global_desc, __device__ void Run(const void* p_a_cyx_k_global_desc,
const Float* __restrict__ p_a_global, const Float* __restrict__ p_a_global,
const void* p_b_cyx_n_h_w_global_desc, const void* p_b_cyx_n_h_w_global_desc,
const Float* __restrict__ p_b_global, const Float* __restrict__ p_b_global,
@@ -524,12 +486,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
const auto a_k_m_global_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_k_m_global_desc); const auto a_cyx_k_global_desc =
const auto b_cyx_n_h_w_global_desc = *reinterpret_cast<const BGlobalDesc*>(p_b_cyx_n_h_w_global_desc); *reinterpret_cast<const AGlobalDesc*>(p_a_cyx_k_global_desc);
const auto b_cyx_n_h_w_global_desc =
*reinterpret_cast<const BGlobalDesc*>(p_b_cyx_n_h_w_global_desc);
const auto c_k_n_h_w_global_desc = const auto c_k_n_h_w_global_desc =
*reinterpret_cast<const CGlobalDesc*>(p_c_k_n_h_w_global_desc); *reinterpret_cast<const CGlobalDesc*>(p_c_k_n_h_w_global_desc);
Run(a_k_m_global_desc, Run(a_cyx_k_global_desc,
p_a_global, p_a_global,
b_cyx_n_h_w_global_desc, b_cyx_n_h_w_global_desc,
p_b_global, p_b_global,
...
@@ -535,7 +535,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
i * src_scalar_step_in_vector); i * src_scalar_step_in_vector);
p_dst[Number<dst_offset>{}] = src_vector[i]; // p_dst[Number<dst_offset>{}] = src_vector[i];
p_dst[Number<dst_offset>{}] = src_vector.Scalars()(i);
}); });
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
...
@@ -28,33 +28,6 @@ __device__ void threadwise_matrix_set_zero_v3(Desc, Float* __restrict__ p_thread
}); });
} }
template <typename SrcDesc,
typename DstDesc,
index_t NSliceRow,
index_t NSliceCol,
index_t DataPerAccess>
struct ThreadwiseMatrixSliceCopy_v3
{
template <typename Data>
__device__ static void Run(const Data* p_src, Data* p_dst)
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
using vector_t = typename vector_type<Data, DataPerAccess>::type;
static_for<0, NSliceRow, 1>{}([&](auto i) {
static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j));
constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j));
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
*reinterpret_cast<const vector_t*>(&p_src[src_offset]);
});
});
}
};
// C[M, N] += transpose(A[K, M]) * B[K, N] // C[M, N] += transpose(A[K, M]) * B[K, N]
// Element of matrix can be vectorized data // Element of matrix can be vectorized data
template <typename ADesc, template <typename ADesc,
@@ -75,9 +48,9 @@ struct ThreadwiseGemm_km_kn_mn_v3
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto M = CDesc{}[I0]; constexpr auto M = CDesc{}.GetLength(I0);
constexpr auto N = CDesc{}[I1]; constexpr auto N = CDesc{}.GetLength(I1);
constexpr auto K = ADesc{}[I0]; constexpr auto K = ADesc{}.GetLength(I0);
static_for<0, K, 1>{}([&](auto k) { static_for<0, K, 1>{}([&](auto k) {
static_for<0, M, 1>{}([&](auto m) { static_for<0, M, 1>{}([&](auto m) {
...
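For reference, the accumulation that ThreadwiseGemm_km_kn_mn_v3 performs is C[M, N] += transpose(A[K, M]) * B[K, N], i.e. A is stored K-major and read transposed. A host-side sketch of the same arithmetic (illustrative only, not the kernel's static_for version):

```cpp
// Host-side reference of C[M, N] += transpose(A[K, M]) * B[K, N].
#include <cstdio>

constexpr int K = 2, M = 3, N = 4;

void gemm_km_kn_mn(const float (&a)[K][M], const float (&b)[K][N], float (&c)[M][N])
{
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                c[m][n] += a[k][m] * b[k][n]; // A is indexed [k][m], i.e. read transposed
}

int main()
{
    float a[K][M] = {{1, 2, 3}, {4, 5, 6}};
    float b[K][N] = {{1, 0, 0, 1}, {0, 1, 1, 0}};
    float c[M][N] = {};
    gemm_km_kn_mn(a, b, c);
    std::printf("c[0][0] = %g, c[2][3] = %g\n", c[0][0], c[2][3]); // 1, 3
    return 0;
}
```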
@@ -76,7 +76,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
constexpr index_t GemmMPerThread = 16; constexpr index_t GemmMPerThread = 16;
constexpr index_t GemmNPerThread = 1; constexpr index_t GemmNPerThread = 1;
constexpr index_t GemmKPerThread = 1; constexpr index_t GemmKPerThread = 4;
constexpr index_t GemmMLevel0Cluster = 1; constexpr index_t GemmMLevel0Cluster = 1;
constexpr index_t GemmNLevel0Cluster = 1; constexpr index_t GemmNLevel0Cluster = 1;
...
@@ -779,7 +779,7 @@ int main(int argc, char* argv[])
#if 1 #if 1
// LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl; // LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
// LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl; // LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl; // LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl; LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
#endif #endif
} }
...