Commit 437c996a authored by Chao Liu

use StaticBuffer for thread matrix A/B in blockwise GEMM

parent 36de63ff
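What the change amounts to: the blockwise GEMM previously held each thread's A/B tile in a raw local array wrapped by make_dynamic_buffer, which is addressed with runtime offsets; it now holds the tile in a buffer created by make_static_buffer, whose size and access indices are compile-time constants, so the compiler can keep the whole tile in registers instead of spilling it to scratch. Below is a minimal host-side sketch of that distinction; StaticBuf, DynBuf and Num are simplified stand-ins for ck's StaticBuffer, DynamicBuffer and Number, not the real interfaces.

#include <cstdio>
#include <type_traits>

// Simplified stand-ins for ck::Number<>, StaticBuffer and DynamicBuffer
// (assumptions for illustration, not the real interfaces).
template <int I>
using Num = std::integral_constant<int, I>;

// "Static" buffer: size and every access index are compile-time constants, so after
// full unrolling each element can be assigned its own register.
template <typename T, int Size>
struct StaticBuf
{
    T data[Size];

    template <int I>
    T& operator()(Num<I>)
    {
        static_assert(I < Size, "out of range");
        return data[I];
    }

    template <int I>
    const T& operator[](Num<I>) const
    {
        static_assert(I < Size, "out of range");
        return data[I];
    }
};

// "Dynamic" buffer: a pointer addressed with runtime indices; on a GPU this tends to
// force the backing storage out of registers.
template <typename T>
struct DynBuf
{
    T* p;
    T& operator()(int i) { return p[i]; }
    const T& operator[](int i) const { return p[i]; }
};

int main()
{
    // Old style: raw thread-local array viewed through a dynamic buffer.
    float raw[4] = {};
    DynBuf<float> a_thread_dyn{raw};
    a_thread_dyn(0) = 1.0f; // runtime index

    // New style: statically sized buffer indexed with compile-time constants.
    StaticBuf<float, 4> a_thread_static{};
    a_thread_static(Num<0>{}) = 2.0f; // compile-time index

    std::printf("%g %g\n", a_thread_dyn[0], a_thread_static[Num<0>{}]);
    return 0;
}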
@@ -545,11 +545,8 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
             make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}),
             make_tuple(Number<NPerThread>{}, Number<1>{}));
 
-        FloatA p_a_thread[a_thread_mtx_desc_.GetElementSpaceSize()];
-        FloatB p_b_thread[b_thread_mtx_desc_.GetElementSpaceSize()];
-
-        auto a_thread_buf = make_dynamic_buffer(p_a_thread);
-        auto b_thread_buf = make_dynamic_buffer(p_b_thread);
+        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_mtx_desc_.GetElementSpaceSize());
+        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_mtx_desc_.GetElementSpaceSize());
 
         constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1r1<FloatA,
                                                                       FloatB,
...
@@ -1379,9 +1379,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
         static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc and DstDesc need to known at compile-time");
 
-#if 0 // debug
         static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
-#endif
 
         static_assert(is_known_at_compile_time<
                           remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
@@ -1437,13 +1435,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
             container_reorder_given_new2old(access_lengths, dim_access_order);
 
         static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
-            // position in slice window
-#if 0 // debug
+#if 0
             // TODO: unable to compile
+            // position in slice window
             constexpr auto data_to_origin_disp_idx =
                 container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
                 src_scalar_per_access;
 #else
+            // position in slice window
             constexpr auto data_to_origin_disp_idx =
                 ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
 #endif
@@ -1470,13 +1469,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
                 src_desc, src_data_coord);
 
 #if 0
-            // TODO: this is slooooooooow!
+            // TODO: this is slooooooooow due to VGPR over-allocation
             src_tmp_buf.template AsType<src_vector_t>()(Number<0>{}) =
                 is_src_valid ? src_buf.template AsType<src_vector_t>()[src_data_coord.GetOffset() /
                                                                        SrcScalarPerVector]
                              : src_vector_t{0};
 #else
-            // this has normal performance but it's hacky
+            // TODO: this is workaround. this has normal performance but it's hacky
             src_tmp_buf.template AsType<src_vector_t>()(Number<0>{}) =
                 is_src_valid
                     ? *reinterpret_cast<const src_vector_t*>(&(reinterpret_cast<const SrcData*>(
...
@@ -191,60 +191,12 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
               typename BOriginIdx,
               typename CBuffer,
               typename COriginIdx>
-    __device__ static void Run_source(const ABuffer& a_buf,
+    __device__ static void Run(const ABuffer& a_buf,
                                AOriginIdx,
                                const BBuffer& b_buf,
                                BOriginIdx,
                                CBuffer& c_buf,
                                COriginIdx)
-    {
-        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
-                          CDesc::IsKnownAtCompileTime(),
-                      "wrong! Desc should be known at compile-time");
-
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-
-        constexpr auto M = CDesc{}.GetLength(I0);
-        constexpr auto N = CDesc{}.GetLength(I1);
-        constexpr auto K = ADesc{}.GetLength(I0);
-
-        constexpr auto a_origin_idx = AOriginIdx{};
-        constexpr auto b_origin_idx = BOriginIdx{};
-        constexpr auto c_origin_idx = COriginIdx{};
-
-        static_for<0, K, 1>{}([&](auto k) {
-            static_for<0, M, 1>{}([&](auto m) {
-                static_for<0, N, 1>{}([&](auto n) {
-                    constexpr auto a_offset =
-                        ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
-                    constexpr auto b_offset =
-                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n));
-                    constexpr auto c_offset =
-                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n));
-
-                    c_buf.template AsType<FloatC>()(c_offset) +=
-                        inner_product_with_conversion<FloatC>{}(
-                            a_buf.template AsType<FloatA>()[a_offset],
-                            b_buf.template AsType<FloatB>()[b_offset]);
-                });
-            });
-        });
-    }
-
-#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
-    template <typename ABuffer,
-              typename AOriginIdx,
-              typename BBuffer,
-              typename BOriginIdx,
-              typename CBuffer,
-              typename COriginIdx>
-    __device__ static void Run_amd_asm(const ABuffer& a_buf,
-                                       AOriginIdx,
-                                       const BBuffer& b_buf,
-                                       BOriginIdx,
-                                       CBuffer& c_buf,
-                                       COriginIdx)
     {
         static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                           CDesc::IsKnownAtCompileTime(),
@@ -258,8 +210,6 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
 
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};
 
         constexpr auto M = CDesc{}.GetLength(I0);
         constexpr auto N = CDesc{}.GetLength(I1);
@@ -269,83 +219,30 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
         constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
         constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
 
-        static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");
-
         static_for<0, K, 1>{}([&](auto k) {
             static_for<0, M, 1>{}([&](auto m) {
-                constexpr auto a_offset = ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
-
-                if constexpr(N == 2)
-                {
-                    constexpr auto b_offset_0 =
-                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
-                    constexpr auto b_offset_1 =
-                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
-
-                    constexpr auto c_offset_0 =
-                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
-                    constexpr auto c_offset_1 =
-                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
-
-                    amd_assembly_outer_product_1x2(a_buf.template AsType<FloatA>()[a_offset],
-                                                   b_buf.template AsType<FloatB>()[b_offset_0],
-                                                   b_buf.template AsType<FloatB>()[b_offset_1],
-                                                   c_buf.template AsType<FloatC>()(c_offset_0),
-                                                   c_buf.template AsType<FloatC>()(c_offset_1));
-                }
-                else if constexpr(N == 4)
-                {
-                    constexpr auto b_offset_0 =
-                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
-                    constexpr auto b_offset_1 =
-                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
-                    constexpr auto b_offset_2 =
-                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I2));
-                    constexpr auto b_offset_3 =
-                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I3));
-
-                    constexpr auto c_offset_0 =
-                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
-                    constexpr auto c_offset_1 =
-                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
-                    constexpr auto c_offset_2 =
-                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I2));
-                    constexpr auto c_offset_3 =
-                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I3));
-
-                    amd_assembly_outer_product_1x4(a_buf.template AsType<FloatA>()[a_offset],
-                                                   b_buf.template AsType<FloatB>()[b_offset_0],
-                                                   b_buf.template AsType<FloatB>()[b_offset_1],
-                                                   b_buf.template AsType<FloatB>()[b_offset_2],
-                                                   b_buf.template AsType<FloatB>()[b_offset_3],
-                                                   c_buf.template AsType<FloatC>()(c_offset_0),
-                                                   c_buf.template AsType<FloatC>()(c_offset_1),
-                                                   c_buf.template AsType<FloatC>()(c_offset_2),
-                                                   c_buf.template AsType<FloatC>()(c_offset_3));
-                }
-            });
-        });
-    }
-#endif
-
-    template <typename ABuffer,
-              typename AOriginIdx,
-              typename BBuffer,
-              typename BOriginIdx,
-              typename CBuffer,
-              typename COriginIdx>
-    __device__ static void Run(const ABuffer& a_buf,
-                               AOriginIdx,
-                               const BBuffer& b_buf,
-                               BOriginIdx,
-                               CBuffer& c_buf,
-                               COriginIdx)
-    {
+                static_for<0, N, 1>{}([&](auto n) {
+                    constexpr index_t a_offset =
+                        ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
+                    constexpr index_t b_offset =
+                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n));
+                    constexpr index_t c_offset =
+                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n));
+
 #if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
-        Run_amd_asm(a_buf, AOriginIdx{}, b_buf, BOriginIdx{}, c_buf, COriginIdx{});
+                    amd_assembly_inner_product(a_buf.template AsType<FloatA>()[Number<a_offset>{}],
+                                               b_buf.template AsType<FloatB>()[Number<b_offset>{}],
+                                               c_buf.template AsType<FloatC>()(Number<c_offset>{}));
 #else
-        Run_source(a_buf, AOriginIdx{}, b_buf, BOriginIdx{}, c_buf, COriginIdx{});
+                    c_buf.template AsType<FloatC>()(Number<c_offset>{}) +=
+                        inner_product_with_conversion<FloatC>{}(
+                            a_buf.template AsType<FloatA>()[Number<a_offset>{}],
+                            b_buf.template AsType<FloatB>()[Number<b_offset>{}]);
 #endif
+                });
+            });
+        });
     }
 };
...
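For reference, the restructured Run is semantically a fully unrolled K x M x N loop in which every innermost step performs one scalar multiply-accumulate, with CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM choosing between the new amd_assembly_inner_product and the generic inner_product_with_conversion path. Below is a minimal host-side sketch of that shape; plain loops and std::fma stand in for static_for and the v_fmac_f32 path, and the names and the USE_FMA_PATH toggle are illustrative, not ck code.

#include <cmath>
#include <cstdio>

// Toggle standing in for CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM (illustrative only).
#define USE_FMA_PATH 1

// C[m][n] += sum_k A[k][m] * B[k][n], i.e. the km_kn_mn threadwise GEMM shape.
template <int K, int M, int N>
void threadwise_gemm_km_kn_mn(const float (&a)[K][M], const float (&b)[K][N], float (&c)[M][N])
{
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
#if USE_FMA_PATH
                // stands in for amd_assembly_inner_product / v_fmac_f32
                c[m][n] = std::fma(a[k][m], b[k][n], c[m][n]);
#else
                // stands in for inner_product_with_conversion<FloatC>
                c[m][n] += a[k][m] * b[k][n];
#endif
            }
}

int main()
{
    float a[2][2] = {{1, 2}, {3, 4}}; // A is K x M
    float b[2][2] = {{5, 6}, {7, 8}}; // B is K x N
    float c[2][2] = {};               // C is M x N
    threadwise_gemm_km_kn_mn(a, b, c);
    std::printf("%g %g %g %g\n", c[0][0], c[0][1], c[1][0], c[1][1]);
    return 0;
}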
@@ -5,6 +5,24 @@
 
 namespace ck {
 
+// c += inner_product(a, b)
+__device__ void amd_assembly_inner_product(const float& a, const float& b, float& c)
+{
+#if CK_USE_AMD_V_FMAC_F32
+    asm volatile("\n \
+    v_fmac_f32 %0, %1, %2 \n \
+    "
+                 : "=v"(c)
+                 : "v"(a), "v"(b), "0"(c));
+#else
+    asm volatile("\n \
+    v_mac_f32 %0, %1, %2 \n \
+    "
+                 : "=v"(c)
+                 : "v"(a), "v"(b), "0"(c));
+#endif
+}
+
 // c0 += inner_product(a, b0)
 // c1 += inner_product(a, b1)
 __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
...