Commit 35d68cf8 authored by Chao Liu

replacing array with vector for tensor data

parent 712babe4
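For context on what "replacing array with vector for tensor data" looks like in the hunks below: raw pointer parameters (e.g. `FloatC* p_c_thread`) plus caller-side `CalculateOffset` arithmetic are replaced by buffer objects (`make_dynamic_buffer`, thread-private static buffers) that are indexed through origin tuples. The following is a minimal host-side sketch of that pattern only; `ThreadBuffer`, `accumulate_raw`, and `accumulate_buf` are hypothetical stand-ins, not the library's actual types.

```cpp
// Hypothetical stand-ins illustrating "raw pointer + offset" vs "buffer + index".
#include <array>
#include <cstdio>

template <int M, int N>
struct ThreadBuffer // stand-in for a thread-private buffer with a row-major descriptor
{
    std::array<float, M * N> data{};
    static constexpr int offset(int row, int col) { return row * N + col; }
    float& operator()(int row, int col) { return data[offset(row, col)]; }
};

// "old" style: raw pointer, offset computed by the caller
void accumulate_raw(float* p_c, int off, float v) { p_c[off] += v; }

// "new" style: buffer plus index, offset computed from the buffer's own layout
template <int M, int N>
void accumulate_buf(ThreadBuffer<M, N>& c_buf, int row, int col, float v)
{
    c_buf(row, col) += v;
}

int main()
{
    ThreadBuffer<2, 4> c{};
    accumulate_raw(c.data.data(), ThreadBuffer<2, 4>::offset(1, 2), 1.0f);
    accumulate_buf(c, 1, 2, 1.0f);
    std::printf("c(1,2) = %f\n", c(1, 2)); // prints 2.000000
    return 0;
}
```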
@@ -503,8 +503,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
__device__ void
Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
template <typename CThreadBuffer>
__device__ void Run_pipelined_2x2(const FloatA* p_a_block,
const FloatB* p_b_block,
CThreadBuffer c_thread_buf) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -549,12 +551,12 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
auto a_thread_buf = make_dynamic_buffer<FloatA>(p_a_thread);
auto b_thread_buf = make_dynamic_buffer<FloatB>(p_b_thread);
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<FloatA,
FloatB,
FloatC,
decltype(a_thread_sub_mtx),
decltype(b_thread_sub_mtx),
decltype(c_thread_sub_mtx)>{};
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1r1<FloatA,
FloatB,
FloatC,
decltype(a_thread_sub_mtx),
decltype(b_thread_sub_mtx),
decltype(c_thread_sub_mtx)>{};
// read A_sub_0
a_thread_copy_.Run(BlockMatrixA{},
@@ -589,13 +591,20 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
threadwise_gemm.Run(a_thread_buf,
make_tuple(Number<0>{}, Number<0>{}),
b_thread_buf,
make_tuple(Number<0>{}, Number<0>{}),
c_thread_buf,
make_tuple(Number<0>{}, Number<0>{}));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(
p_a_thread,
p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(0, NPerThreadSubC)));
threadwise_gemm.Run(a_thread_buf,
make_tuple(Number<0>{}, Number<0>{}),
b_thread_buf,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
c_thread_buf,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}));
// loop over rest of k
static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
@@ -608,10 +617,12 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(
p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread,
p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
threadwise_gemm.Run(a_thread_buf,
make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
b_thread_buf,
make_tuple(Number<0>{}, Number<0>{}),
c_thread_buf,
make_tuple(Number<MPerThreadSubC>{}, Number<0>{}));
// read B_sub_0
b_thread_copy_.Run(BlockMatrixB{},
@@ -622,11 +633,12 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(
p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread +
c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
threadwise_gemm.Run(a_thread_buf,
make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
b_thread_buf,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
c_thread_buf,
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
// read B_sub_1
b_thread_copy_.Run(BlockMatrixB{},
@@ -645,30 +657,42 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
threadwise_gemm.Run(a_thread_buf,
make_tuple(Number<0>{}, Number<0>{}),
b_thread_buf,
make_tuple(Number<0>{}, Number<0>{}),
c_thread_buf,
make_tuple(Number<0>{}, Number<0>{}));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(
p_a_thread,
p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(0, NPerThreadSubC)));
threadwise_gemm.Run(a_thread_buf,
make_tuple(Number<0>{}, Number<0>{}),
b_thread_buf,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
c_thread_buf,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}));
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(
p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread,
p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
threadwise_gemm.Run(a_thread_buf,
make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
b_thread_buf,
make_tuple(Number<0>{}, Number<0>{}),
c_thread_buf,
make_tuple(Number<MPerThreadSubC>{}, Number<0>{}));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(
p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread +
c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
threadwise_gemm.Run(a_thread_buf,
make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
b_thread_buf,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
c_thread_buf,
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
}
__device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
template <typename CThreadBuffer>
__device__ void
Run(const FloatA* p_a_block, const FloatB* p_b_block, CThreadBuffer c_thread_buf) const
{
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
constexpr auto I0 = Number<0>{};
@@ -682,14 +706,14 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
if constexpr(MRepeat == 2 && NRepeat == 2)
{
Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
Run_pipelined_2x2(p_a_block, p_b_block, c_thread_buf);
}
else
{
Run_naive(p_a_block, p_b_block, p_c_thread);
Run_naive(p_a_block, p_b_block, c_thread_buf);
}
#else
Run_naive(p_a_block, p_b_block, p_c_thread);
Run_naive(p_a_block, p_b_block, c_thread_buf);
#endif
}
};
......
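To make the 2x2 pipelined update order above easier to follow, here is a simplified host-side sketch of how the four C sub-tiles accumulate `transpose(A_sub_x) * B_sub_y`. It uses plain runtime loops and flat row-major arrays with illustrative sizes; there is no LDS, pipelining, or inline asm, and `gemm_sub` is a hypothetical helper, not the library's threadwise GEMM.

```cpp
// Simplified 2x2 sub-tile update order: C_sub_xy += transpose(A_sub_x) * B_sub_y.
#include <array>
#include <cstdio>

constexpr int K = 2, MSub = 2, NSub = 2; // per-thread sub-tile sizes (illustrative)

using Mat = std::array<float, 16>; // flat storage large enough for A (2x4), B (2x4), C (4x4)

// C[m0+m, n0+n] += sum_k A[k, m0+m] * B[k, n0+n] over one MSub x NSub sub-tile
void gemm_sub(const Mat& a, const Mat& b, Mat& c, int m0, int n0)
{
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < MSub; ++m)
            for(int n = 0; n < NSub; ++n)
                c[(m0 + m) * (2 * NSub) + (n0 + n)] +=
                    a[k * (2 * MSub) + (m0 + m)] * b[k * (2 * NSub) + (n0 + n)];
}

int main()
{
    Mat a{}, b{}, c{};
    a.fill(1.0f);
    b.fill(1.0f);

    // the four sub-tile updates, in the same order as the pipelined code above
    gemm_sub(a, b, c, 0, 0);       // C_sub_00 += A_sub_0^T * B_sub_0
    gemm_sub(a, b, c, 0, NSub);    // C_sub_01 += A_sub_0^T * B_sub_1
    gemm_sub(a, b, c, MSub, 0);    // C_sub_10 += A_sub_1^T * B_sub_0
    gemm_sub(a, b, c, MSub, NSub); // C_sub_11 += A_sub_1^T * B_sub_1

    std::printf("c[0] = %f\n", c[0]); // 2.000000 (accumulated over K = 2)
    return 0;
}
```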
@@ -732,6 +732,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
// register allocation for output
FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
auto c_thread_buf = make_dynamic_buffer<FloatAcc>(p_c_thread);
// zero out threadwise output
threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);
@@ -789,7 +791,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
blockwise_gemm.Run(p_a_block_even, p_b_block_even, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
@@ -812,7 +814,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
@@ -839,7 +841,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
blockwise_gemm.Run(p_a_block_double, p_b_block_double, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
@@ -850,14 +852,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double + a_block_space_size,
p_b_block_double + b_block_space_size,
p_c_thread);
c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
blockwise_gemm.Run(p_a_block_double, p_b_block_double, c_thread_buf);
}
// output: register to global memory
......
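The hunks above implement LDS double buffering: while the blockwise GEMM consumes one buffer ("even" or "odd"), the next K tile is fetched into the other. A stripped-down host-side sketch of that ping-pong loop structure follows; `load_tile` and `gemm_on` are hypothetical stand-ins for the blockwise copy and blockwise GEMM, and the sizes are illustrative only.

```cpp
// Ping-pong double buffering: prologue load, steady-state loop, epilogue compute.
#include <array>
#include <cstdio>

constexpr int Tiles    = 6; // number of K tiles to consume (illustrative)
constexpr int TileSize = 4;

using Tile = std::array<float, TileSize>;

Tile load_tile(int i) // stand-in for the blockwise RunRead/RunWrite pair
{
    Tile t;
    t.fill(float(i));
    return t;
}

float gemm_on(const Tile& t) // stand-in for blockwise_gemm.Run
{
    float acc = 0.0f;
    for(float v : t)
        acc += v;
    return acc;
}

int main()
{
    Tile buf[2];           // "even" and "odd" buffers
    buf[0] = load_tile(0); // prologue: fill the first buffer

    float c = 0.0f;
    for(int i = 0; i + 1 < Tiles; ++i)
    {
        buf[(i + 1) % 2] = load_tile(i + 1); // fetch the next tile into the other buffer
        c += gemm_on(buf[i % 2]);            // compute on the current buffer
    }
    c += gemm_on(buf[(Tiles - 1) % 2]); // epilogue: compute on the last tile

    std::printf("c = %f\n", c); // 4 * (0+1+2+3+4+5) = 60.000000
    return 0;
}
```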
@@ -1370,7 +1370,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
const SrcData* p_src,
const DstDesc&,
const DstRefToOriginDisplacement&,
DstBuffer dst_buf) const
DstBuffer& dst_buf) const
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time");
......
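The one-line change above takes the `DstBuffer` parameter by reference, so the transfer writes into the caller's buffer rather than a local copy. A minimal illustration of the difference (generic C++, not the library's buffer types):

```cpp
// Pass-by-value writes into a copy; pass-by-reference is visible to the caller.
#include <array>
#include <cstdio>

using Buf = std::array<float, 4>;

void write_by_value(Buf dst) { dst[0] = 1.0f; }      // modifies a copy only
void write_by_reference(Buf& dst) { dst[0] = 1.0f; } // modifies the caller's buffer

int main()
{
    Buf b{};
    write_by_value(b);
    std::printf("by value:     %f\n", b[0]); // 0.000000
    write_by_reference(b);
    std::printf("by reference: %f\n", b[0]); // 1.000000
    return 0;
}
```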
@@ -168,5 +168,186 @@ struct ThreadwiseGemm_km_kn_mn_v1
}
};
// C[M, N] += transpose(A[K, M]) * B[K, N]
// Elements of the matrices can be vectorized data types
// Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time
// 2. ABuffer, BBuffer, CBuffer are static buffers
// 3. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
typename FloatB,
typename FloatC,
typename ADesc,
typename BDesc,
typename CDesc,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemm_km_kn_mn_v1r1
{
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run_source(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto M = CDesc{}.GetLength(I0);
constexpr auto N = CDesc{}.GetLength(I1);
constexpr auto K = ADesc{}.GetLength(I0);
constexpr auto a_origin_idx = AOriginIdx{};
constexpr auto b_origin_idx = BOriginIdx{};
constexpr auto c_origin_idx = COriginIdx{};
static_for<0, K, 1>{}([&](auto k) {
static_for<0, M, 1>{}([&](auto m) {
static_for<0, N, 1>{}([&](auto n) {
constexpr auto a_offset =
ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
constexpr auto b_offset =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n));
constexpr auto c_offset =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n));
c_buf.template AsType<FloatC>()(c_offset) +=
inner_product_with_conversion<FloatC>{}(
a_buf.template AsType<FloatA>()[a_offset],
b_buf.template AsType<FloatB>()[b_offset]);
});
});
});
}
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run_amd_asm(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto M = CDesc{}.GetLength(I0);
constexpr auto N = CDesc{}.GetLength(I1);
constexpr auto K = ADesc{}.GetLength(I0);
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_assert(N == 4 || N == 2, "wrong! this config is not supported by asm yet");
static_for<0, K, 1>{}([&](auto k) {
static_for<0, M, 1>{}([&](auto m) {
constexpr auto a_offset = ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
if constexpr(N == 2)
{
constexpr auto b_offset_0 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
constexpr auto b_offset_1 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
constexpr auto c_offset_0 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
constexpr auto c_offset_1 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
amd_assembly_outer_product_1x2(a_buf.template AsType<FloatA>()[a_offset],
b_buf.template AsType<FloatB>()[b_offset_0],
b_buf.template AsType<FloatB>()[b_offset_1],
c_buf.template AsType<FloatC>()(c_offset_0),
c_buf.template AsType<FloatC>()(c_offset_1));
}
else if constexpr(N == 4)
{
constexpr auto b_offset_0 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
constexpr auto b_offset_1 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
constexpr auto b_offset_2 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I2));
constexpr auto b_offset_3 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I3));
constexpr auto c_offset_0 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
constexpr auto c_offset_1 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
constexpr auto c_offset_2 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I2));
constexpr auto c_offset_3 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I3));
amd_assembly_outer_product_1x4(a_buf.template AsType<FloatA>()[a_offset],
b_buf.template AsType<FloatB>()[b_offset_0],
b_buf.template AsType<FloatB>()[b_offset_1],
b_buf.template AsType<FloatB>()[b_offset_2],
b_buf.template AsType<FloatB>()[b_offset_3],
c_buf.template AsType<FloatC>()(c_offset_0),
c_buf.template AsType<FloatC>()(c_offset_1),
c_buf.template AsType<FloatC>()(c_offset_2),
c_buf.template AsType<FloatC>()(c_offset_3));
}
});
});
}
#endif
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
Run_amd_asm(a_buf, AOriginIdx{}, b_buf, BOriginIdx{}, c_buf, COriginIdx{});
#else
Run_source(a_buf, AOriginIdx{}, b_buf, BOriginIdx{}, c_buf, COriginIdx{});
#endif
}
};
} // namespace ck
#endif
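As the header comment states, `ThreadwiseGemm_km_kn_mn_v1r1` computes `C[M, N] += transpose(A[K, M]) * B[K, N]` with descriptors and origin indices known at compile time. Stripped of the descriptor machinery, `static_for`, and vectorized buffer access, `Run_source` reduces to the following host-side sketch with runtime loops, flat row-major arrays, and illustrative sizes.

```cpp
// Plain-loop equivalent of the threadwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N].
#include <array>
#include <cstdio>

constexpr int K = 2, M = 2, N = 4;

void threadwise_gemm(const std::array<float, K * M>& a,
                     const std::array<float, K * N>& b,
                     std::array<float, M * N>& c)
{
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                c[m * N + n] += a[k * M + m] * b[k * N + n];
}

int main()
{
    std::array<float, K * M> a;
    std::array<float, K * N> b;
    std::array<float, M * N> c{};
    a.fill(1.0f);
    b.fill(2.0f);

    threadwise_gemm(a, b, c);

    std::printf("c[0] = %f\n", c[0]); // K * 1.0 * 2.0 = 4.000000
    return 0;
}
```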