"vscode:/vscode.git/clone" did not exist on "29496c95d3d04eafae5eb9d0de2b3e4673df3a73"
Commit 35d68cf8 authored by Chao Liu

replacing array with vector for tensor data

parent 712babe4
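
Note: the diff below changes the blockwise/threadwise GEMM entry points so that, instead of a raw FloatC* that callers pre-offset at run time via CalculateOffset, they take a buffer object (c_thread_buf) plus origin indices packed into Number<>/make_tuple types, so element offsets become compile-time constants. A minimal standalone sketch of that calling-convention change (illustrative only, not CK code; Num, sum_old, sum_new are hypothetical names):

#include <array>
#include <cstddef>
#include <cstdio>

template <int N>
struct Num { static constexpr int value = N; }; // stand-in for ck::Number<N>

// old style: offset arithmetic happens at the call site, on a runtime pointer
float sum_old(const float* p, int n)
{
    float s = 0;
    for(int i = 0; i < n; ++i) s += p[i];
    return s;
}

// new style: whole buffer + origin carried in the type; every index is a constant expression
template <int Origin, int N, std::size_t Size>
float sum_new(const std::array<float, Size>& buf, Num<Origin>, Num<N>)
{
    float s = 0;
    for(int i = 0; i < N; ++i) s += buf[Origin + i];
    return s;
}

int main()
{
    std::array<float, 8> buf{0, 1, 2, 3, 4, 5, 6, 7};
    std::printf("%f\n", sum_old(buf.data() + 4, 2));       // caller offsets the pointer
    std::printf("%f\n", sum_new(buf, Num<4>{}, Num<2>{})); // origin travels as a type
}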
@@ -503,8 +503,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
             level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
     }
 
-    __device__ void
-    Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
+    template <typename CThreadBuffer>
+    __device__ void Run_pipelined_2x2(const FloatA* p_a_block,
+                                      const FloatB* p_b_block,
+                                      CThreadBuffer c_thread_buf) const
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
@@ -549,12 +551,12 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         auto a_thread_buf = make_dynamic_buffer<FloatA>(p_a_thread);
         auto b_thread_buf = make_dynamic_buffer<FloatB>(p_b_thread);
 
-        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<FloatA,
-                                                                    FloatB,
-                                                                    FloatC,
-                                                                    decltype(a_thread_sub_mtx),
-                                                                    decltype(b_thread_sub_mtx),
-                                                                    decltype(c_thread_sub_mtx)>{};
+        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1r1<FloatA,
+                                                                      FloatB,
+                                                                      FloatC,
+                                                                      decltype(a_thread_sub_mtx),
+                                                                      decltype(b_thread_sub_mtx),
+                                                                      decltype(c_thread_sub_mtx)>{};
 
         // read A_sub_0
         a_thread_copy_.Run(BlockMatrixA{},
@@ -589,13 +591,20 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            a_thread_buf);
 
         // C_sub_00 += transpose(A_sub_0) * B_sub_0
-        threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(Number<0>{}, Number<0>{}),
+                            b_thread_buf,
+                            make_tuple(Number<0>{}, Number<0>{}),
+                            c_thread_buf,
+                            make_tuple(Number<0>{}, Number<0>{}));
 
         // C_sub_01 += transpose(A_sub_0) * B_sub_1
-        threadwise_gemm.Run(
-            p_a_thread,
-            p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-            p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(0, NPerThreadSubC)));
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(Number<0>{}, Number<0>{}),
+                            b_thread_buf,
+                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                            c_thread_buf,
+                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}));
 
         // loop over rest of k
         static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
@@ -608,10 +617,12 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                                a_thread_buf);
 
             // C_sub_10 += transpose(A_sub_1) * B_sub_0
-            threadwise_gemm.Run(
-                p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
-                p_b_thread,
-                p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
+            threadwise_gemm.Run(a_thread_buf,
+                                make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                                b_thread_buf,
+                                make_tuple(Number<0>{}, Number<0>{}),
+                                c_thread_buf,
+                                make_tuple(Number<MPerThreadSubC>{}, Number<0>{}));
 
             // read B_sub_0
             b_thread_copy_.Run(BlockMatrixB{},
@@ -622,11 +633,12 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                                b_thread_buf);
 
             // C_sub_11 += transpose(A_sub_1) * B_sub_1
-            threadwise_gemm.Run(
-                p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
-                p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-                p_c_thread +
-                c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
+            threadwise_gemm.Run(a_thread_buf,
+                                make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                                b_thread_buf,
+                                make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                                c_thread_buf,
+                                make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
 
             // read B_sub_1
             b_thread_copy_.Run(BlockMatrixB{},
@@ -645,30 +657,42 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            a_thread_buf);
 
         // C_sub_00 += transpose(A_sub_0) * B_sub_0
-        threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(Number<0>{}, Number<0>{}),
+                            b_thread_buf,
+                            make_tuple(Number<0>{}, Number<0>{}),
+                            c_thread_buf,
+                            make_tuple(Number<0>{}, Number<0>{}));
 
         // C_sub_01 += transpose(A_sub_0) * B_sub_1
-        threadwise_gemm.Run(
-            p_a_thread,
-            p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-            p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(0, NPerThreadSubC)));
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(Number<0>{}, Number<0>{}),
+                            b_thread_buf,
+                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                            c_thread_buf,
+                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}));
         });
 
         // C_sub_10 += transpose(A_sub_1) * B_sub_0
-        threadwise_gemm.Run(
-            p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
-            p_b_thread,
-            p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                            b_thread_buf,
+                            make_tuple(Number<0>{}, Number<0>{}),
+                            c_thread_buf,
+                            make_tuple(Number<MPerThreadSubC>{}, Number<0>{}));
 
         // C_sub_11 += transpose(A_sub_1) * B_sub_1
-        threadwise_gemm.Run(
-            p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
-            p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-            p_c_thread +
-            c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                            b_thread_buf,
+                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                            c_thread_buf,
+                            make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
     }
 
-    __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
+    template <typename CThreadBuffer>
+    __device__ void
+    Run(const FloatA* p_a_block, const FloatB* p_b_block, CThreadBuffer c_thread_buf) const
     {
 #if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
         constexpr auto I0 = Number<0>{};
@@ -682,14 +706,14 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         if constexpr(MRepeat == 2 && NRepeat == 2)
         {
-            Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
+            Run_pipelined_2x2(p_a_block, p_b_block, c_thread_buf);
         }
         else
         {
-            Run_naive(p_a_block, p_b_block, p_c_thread);
+            Run_naive(p_a_block, p_b_block, c_thread_buf);
         }
 #else
-        Run_naive(p_a_block, p_b_block, p_c_thread);
+        Run_naive(p_a_block, p_b_block, c_thread_buf);
 #endif
     }
 };
...
@@ -732,6 +732,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         // register allocation for output
         FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
 
+        auto c_thread_buf = make_dynamic_buffer<FloatAcc>(p_c_thread);
+
         // zero out threadwise output
         threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);
@@ -789,7 +791,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
 
         // LDS double buffer: GEMM on current data
-        blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
+        blockwise_gemm.Run(p_a_block_even, p_b_block_even, c_thread_buf);
 
         // LDS double buffer: store next data to LDS
         a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
@@ -812,7 +814,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
 
         // LDS double buffer: GEMM on current data
-        blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
+        blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, c_thread_buf);
 
         // LDS double buffer: store next data to LDS
         a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
@@ -839,7 +841,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
 
         // LDS double buffer: GEMM on 2nd-last data
-        blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+        blockwise_gemm.Run(p_a_block_double, p_b_block_double, c_thread_buf);
 
         // LDS double buffer: store last data to LDS
         a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
@@ -850,14 +852,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
 
             // LDS double buffer: GEMM on last data
             blockwise_gemm.Run(p_a_block_double + a_block_space_size,
                                p_b_block_double + b_block_space_size,
-                               p_c_thread);
+                               c_thread_buf);
         }
         else // if has 1 iteration left
         {
             __syncthreads();
 
             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+            blockwise_gemm.Run(p_a_block_double, p_b_block_double, c_thread_buf);
         }
 
         // output: register to global memory
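
Note: the hunks above sit inside an LDS double-buffer pipeline: while the GEMM consumes one of two block buffers, the next K-slice is read from global memory and written into the other, with __syncthreads() separating the phases and 2-iteration/1-iteration tails handled separately. A host-side toy of the same buffer ping-pong (illustrative only, not CK code; load/compute stand in for the blockwise copy and GEMM):

#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    const std::vector<int> global = {1, 2, 3, 4, 5, 6, 7, 8}; // stands in for global memory
    const int tile = 2;
    int even[2], odd[2]; // the two "LDS" buffers
    long acc = 0;

    auto load = [&](int* dst, std::size_t pos) { // RunRead + RunWrite stand-in
        for(int i = 0; i < tile; ++i) dst[i] = global[pos + i];
    };
    auto compute = [&](const int* src) { // blockwise_gemm.Run stand-in
        for(int i = 0; i < tile; ++i) acc += src[i];
    };

    load(even, 0); // prologue: fill the first buffer
    bool cur_even = true;
    for(std::size_t pos = tile; pos < global.size(); pos += tile)
    {
        load(cur_even ? odd : even, pos); // fetch the next tile into the idle buffer
        compute(cur_even ? even : odd);   // "GEMM" on the current buffer
        cur_even = !cur_even;
    }
    compute(cur_even ? even : odd); // epilogue: consume the last buffer

    std::printf("%ld\n", acc); // prints 36
}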
...
@@ -1370,7 +1370,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
                   const SrcData* p_src,
                   const DstDesc&,
                   const DstRefToOriginDisplacement&,
-                  DstBuffer dst_buf) const
+                  DstBuffer& dst_buf) const
     {
         static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc and DstDesc need to known at compile-time");
...
@@ -168,5 +168,186 @@ struct ThreadwiseGemm_km_kn_mn_v1
     }
 };
 
+// C[M, N] += transpose(A[K, M]) * B[K, N]
+// Element of matrix can be vectorized data
+// Assume:
+//   1. ADesc, BDesc, CDesc are known at compile-time
+//   2. ABuffer, BBuffer, CBuffer are static buffer
+//   3. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
+template <typename FloatA,
+          typename FloatB,
+          typename FloatC,
+          typename ADesc,
+          typename BDesc,
+          typename CDesc,
+          typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
+                                      CDesc::IsKnownAtCompileTime(),
+                                  bool>::type = false>
+struct ThreadwiseGemm_km_kn_mn_v1r1
+{
+    template <typename ABuffer,
+              typename AOriginIdx,
+              typename BBuffer,
+              typename BOriginIdx,
+              typename CBuffer,
+              typename COriginIdx>
+    __device__ static void Run_source(const ABuffer& a_buf,
+                                      AOriginIdx,
+                                      const BBuffer& b_buf,
+                                      BOriginIdx,
+                                      CBuffer& c_buf,
+                                      COriginIdx)
+    {
+        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
+                          CDesc::IsKnownAtCompileTime(),
+                      "wrong! Desc should be known at compile-time");
+
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+
+        constexpr auto M = CDesc{}.GetLength(I0);
+        constexpr auto N = CDesc{}.GetLength(I1);
+        constexpr auto K = ADesc{}.GetLength(I0);
+
+        constexpr auto a_origin_idx = AOriginIdx{};
+        constexpr auto b_origin_idx = BOriginIdx{};
+        constexpr auto c_origin_idx = COriginIdx{};
+
+        static_for<0, K, 1>{}([&](auto k) {
+            static_for<0, M, 1>{}([&](auto m) {
+                static_for<0, N, 1>{}([&](auto n) {
+                    constexpr auto a_offset =
+                        ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
+                    constexpr auto b_offset =
+                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n));
+                    constexpr auto c_offset =
+                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n));
+
+                    c_buf.template AsType<FloatC>()(c_offset) +=
+                        inner_product_with_conversion<FloatC>{}(
+                            a_buf.template AsType<FloatA>()[a_offset],
+                            b_buf.template AsType<FloatB>()[b_offset]);
+                });
+            });
+        });
+    }
+
+#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
+    template <typename ABuffer,
+              typename AOriginIdx,
+              typename BBuffer,
+              typename BOriginIdx,
+              typename CBuffer,
+              typename COriginIdx>
+    __device__ static void Run_amd_asm(const ABuffer& a_buf,
+                                       AOriginIdx,
+                                       const BBuffer& b_buf,
+                                       BOriginIdx,
+                                       CBuffer& c_buf,
+                                       COriginIdx)
+    {
+        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
+                          CDesc::IsKnownAtCompileTime(),
+                      "wrong! Desc should be known at compile-time");
+
+        static_assert(
+            is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
+                is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
+                is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
+            "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
+
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+
+        constexpr auto M = CDesc{}.GetLength(I0);
+        constexpr auto N = CDesc{}.GetLength(I1);
+        constexpr auto K = ADesc{}.GetLength(I0);
+
+        constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
+        constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
+        constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
+
+        static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");
+
+        static_for<0, K, 1>{}([&](auto k) {
+            static_for<0, M, 1>{}([&](auto m) {
+                constexpr auto a_offset = ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
+
+                if constexpr(N == 2)
+                {
+                    constexpr auto b_offset_0 =
+                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
+                    constexpr auto b_offset_1 =
+                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
+
+                    constexpr auto c_offset_0 =
+                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
+                    constexpr auto c_offset_1 =
+                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
+
+                    amd_assembly_outer_product_1x2(a_buf.template AsType<FloatA>()[a_offset],
+                                                   b_buf.template AsType<FloatB>()[b_offset_0],
+                                                   b_buf.template AsType<FloatB>()[b_offset_1],
+                                                   c_buf.template AsType<FloatC>()(c_offset_0),
+                                                   c_buf.template AsType<FloatC>()(c_offset_1));
+                }
+                else if constexpr(N == 4)
+                {
+                    constexpr auto b_offset_0 =
+                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
+                    constexpr auto b_offset_1 =
+                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
+                    constexpr auto b_offset_2 =
+                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I2));
+                    constexpr auto b_offset_3 =
+                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I3));
+
+                    constexpr auto c_offset_0 =
+                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
+                    constexpr auto c_offset_1 =
+                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
+                    constexpr auto c_offset_2 =
+                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I2));
+                    constexpr auto c_offset_3 =
+                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I3));
+
+                    amd_assembly_outer_product_1x4(a_buf.template AsType<FloatA>()[a_offset],
+                                                   b_buf.template AsType<FloatB>()[b_offset_0],
+                                                   b_buf.template AsType<FloatB>()[b_offset_1],
+                                                   b_buf.template AsType<FloatB>()[b_offset_2],
+                                                   b_buf.template AsType<FloatB>()[b_offset_3],
+                                                   c_buf.template AsType<FloatC>()(c_offset_0),
+                                                   c_buf.template AsType<FloatC>()(c_offset_1),
+                                                   c_buf.template AsType<FloatC>()(c_offset_2),
+                                                   c_buf.template AsType<FloatC>()(c_offset_3));
+                }
+            });
+        });
+    }
+#endif
+
+    template <typename ABuffer,
+              typename AOriginIdx,
+              typename BBuffer,
+              typename BOriginIdx,
+              typename CBuffer,
+              typename COriginIdx>
+    __device__ static void Run(const ABuffer& a_buf,
+                               AOriginIdx,
+                               const BBuffer& b_buf,
+                               BOriginIdx,
+                               CBuffer& c_buf,
+                               COriginIdx)
+    {
+#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
+        Run_amd_asm(a_buf, AOriginIdx{}, b_buf, BOriginIdx{}, c_buf, COriginIdx{});
+#else
+        Run_source(a_buf, AOriginIdx{}, b_buf, BOriginIdx{}, c_buf, COriginIdx{});
+#endif
+    }
+};
+
 } // namespace ck
 #endif
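
Note: the accumulation the new ThreadwiseGemm_km_kn_mn_v1r1::Run_source performs is C[M, N] += transpose(A[K, M]) * B[K, N], i.e. A is stored K-by-M so element (m, k) of A^T is a[k][m]. A plain host-side sketch of the same k/m/n loop nest over flat arrays (illustrative only, not CK code):

#include <cstdio>

constexpr int K = 2, M = 2, N = 4;

// C += A^T * B, accumulated over k, matching the static_for nest in Run_source
void gemm_km_kn_mn(const float a[K][M], const float b[K][N], float c[M][N])
{
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                c[m][n] += a[k][m] * b[k][n];
}

int main()
{
    float a[K][M] = {{1, 2}, {3, 4}};
    float b[K][N] = {{1, 0, 0, 1}, {0, 1, 1, 0}};
    float c[M][N] = {}; // zero-initialized, like threadwise_matrix_set_zero_v2
    gemm_km_kn_mn(a, b, c);
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
            std::printf("%4.0f", c[m][n]);
        std::printf("\n");
    }
}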