Commit 437c996a authored by Chao Liu

use StaticBuffer for thread matrix A/B in blockwise GEMM

parent 36de63ff
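The hunks below replace the pointer-backed thread arrays for the A/B thread matrices (raw FloatA/FloatB arrays wrapped by make_dynamic_buffer) with make_static_buffer, so the buffers are sized and indexed at compile time. As a rough, hypothetical illustration of the idea only (this is not ck's StaticBuffer; StaticBufferSketch, make_static_buffer_sketch, and the element count 8 are invented for the sketch):

```cpp
// Hypothetical sketch, not ck's StaticBuffer: a thread buffer whose size and
// element indices are compile-time constants, instead of a raw array wrapped
// by make_dynamic_buffer.
#include <array>
#include <cstddef>

template <typename T, std::size_t N>
struct StaticBufferSketch
{
    std::array<T, N> data_{};

    // compile-time index: every access names a fixed element
    template <std::size_t I>
    constexpr T& At()
    {
        static_assert(I < N, "index out of range");
        return std::get<I>(data_);
    }
};

template <typename T, std::size_t N>
constexpr StaticBufferSketch<T, N> make_static_buffer_sketch()
{
    return {};
}

int main()
{
    // loosely mirrors: auto a_thread_buf = make_static_buffer<FloatA>(...GetElementSpaceSize());
    auto a_thread_buf = make_static_buffer_sketch<float, 8>();
    a_thread_buf.At<0>() = 1.0f;
    return 0;
}
```

With the index a template parameter, each element is a distinct named object, which is generally friendlier to register allocation than runtime-indexed storage (compare the VGPR over-allocation TODO in the slice-transfer hunk below).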
@@ -545,11 +545,8 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}),
make_tuple(Number<NPerThread>{}, Number<1>{}));
FloatA p_a_thread[a_thread_mtx_desc_.GetElementSpaceSize()];
FloatB p_b_thread[b_thread_mtx_desc_.GetElementSpaceSize()];
auto a_thread_buf = make_dynamic_buffer(p_a_thread);
auto b_thread_buf = make_dynamic_buffer(p_b_thread);
auto a_thread_buf = make_static_buffer<FloatA>(a_thread_mtx_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<FloatB>(b_thread_mtx_desc_.GetElementSpaceSize());
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1r1<FloatA,
FloatB,
......
@@ -1379,9 +1379,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time");
#if 0 // debug
static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
#endif
static_assert(is_known_at_compile_time<
remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
@@ -1437,13 +1435,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
container_reorder_given_new2old(access_lengths, dim_access_order);
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// position in slice window
#if 0 // debug
// TODO: unable to compile
#if 0
// TODO: unable to compile
// position in slice window
constexpr auto data_to_origin_disp_idx =
container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
src_scalar_per_access;
#else
// position in slice window
constexpr auto data_to_origin_disp_idx =
ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
#endif
@@ -1470,13 +1469,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
src_desc, src_data_coord);
#if 0
// TODO: this is slooooooooow!
// TODO: this is slooooooooow due to VGPR over-allocation
src_tmp_buf.template AsType<src_vector_t>()(Number<0>{}) =
is_src_valid ? src_buf.template AsType<src_vector_t>()[src_data_coord.GetOffset() /
SrcScalarPerVector]
: src_vector_t{0};
#else
// this has normal performance but it's hacky
// TODO: this is workaround. this has normal performance but it's hacky
src_tmp_buf.template AsType<src_vector_t>()(Number<0>{}) =
is_src_valid
? *reinterpret_cast<const src_vector_t*>(&(reinterpret_cast<const SrcData*>(
......
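The hunk above keeps the reinterpret_cast branch: rather than indexing the source buffer by vector type (the #if 0 branch, noted as slow due to VGPR over-allocation), it reinterprets the scalar pointer, displaced by a scalar offset, as a pointer to the vector type and issues one wide load. A minimal stand-alone sketch of that access pattern, with invented names (float4_t, load_vec4) and without ck's buffer classes:

```cpp
// Stand-alone sketch of the "hacky but fast" branch above; float4_t and
// load_vec4 are invented names, and alignment is the caller's responsibility.
#include <cstdio>

typedef float float4_t __attribute__((vector_size(16))); // 4 floats (GCC/Clang extension)

float4_t load_vec4(const float* p_src, int scalar_offset, bool is_valid)
{
    float4_t zero = {0.0f, 0.0f, 0.0f, 0.0f};
    // reinterpret the scalar pointer at a scalar (not vector) offset as a
    // vector pointer and issue one 4-wide load, as in the #else branch above
    return is_valid ? *reinterpret_cast<const float4_t*>(p_src + scalar_offset) : zero;
}

int main()
{
    alignas(16) float src[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    float4_t v = load_vec4(src, 4, true);
    std::printf("%f %f %f %f\n", v[0], v[1], v[2], v[3]);
    return 0;
}
```

The aliasing and alignment assumptions are left to the caller, which is why the added comment in the diff calls this a workaround rather than a fix.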
@@ -191,60 +191,12 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run_source(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto M = CDesc{}.GetLength(I0);
constexpr auto N = CDesc{}.GetLength(I1);
constexpr auto K = ADesc{}.GetLength(I0);
constexpr auto a_origin_idx = AOriginIdx{};
constexpr auto b_origin_idx = BOriginIdx{};
constexpr auto c_origin_idx = COriginIdx{};
static_for<0, K, 1>{}([&](auto k) {
static_for<0, M, 1>{}([&](auto m) {
static_for<0, N, 1>{}([&](auto n) {
constexpr auto a_offset =
ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
constexpr auto b_offset =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n));
constexpr auto c_offset =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n));
c_buf.template AsType<FloatC>()(c_offset) +=
inner_product_with_conversion<FloatC>{}(
a_buf.template AsType<FloatA>()[a_offset],
b_buf.template AsType<FloatB>()[b_offset]);
});
});
});
}
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run_amd_asm(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
@@ -258,8 +210,6 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto M = CDesc{}.GetLength(I0);
constexpr auto N = CDesc{}.GetLength(I1);
@@ -269,83 +219,30 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");
static_for<0, K, 1>{}([&](auto k) {
static_for<0, M, 1>{}([&](auto m) {
constexpr auto a_offset = ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
static_for<0, N, 1>{}([&](auto n) {
if constexpr(N == 2)
{
constexpr auto b_offset_0 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
constexpr auto b_offset_1 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
constexpr auto c_offset_0 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
constexpr auto c_offset_1 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
amd_assembly_outer_product_1x2(a_buf.template AsType<FloatA>()[a_offset],
b_buf.template AsType<FloatB>()[b_offset_0],
b_buf.template AsType<FloatB>()[b_offset_1],
c_buf.template AsType<FloatC>()(c_offset_0),
c_buf.template AsType<FloatC>()(c_offset_1));
}
else if constexpr(N == 4)
{
constexpr auto b_offset_0 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
constexpr auto b_offset_1 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
constexpr auto b_offset_2 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I2));
constexpr auto b_offset_3 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I3));
constexpr auto c_offset_0 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
constexpr auto c_offset_1 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
constexpr auto c_offset_2 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I2));
constexpr auto c_offset_3 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I3));
amd_assembly_outer_product_1x4(a_buf.template AsType<FloatA>()[a_offset],
b_buf.template AsType<FloatB>()[b_offset_0],
b_buf.template AsType<FloatB>()[b_offset_1],
b_buf.template AsType<FloatB>()[b_offset_2],
b_buf.template AsType<FloatB>()[b_offset_3],
c_buf.template AsType<FloatC>()(c_offset_0),
c_buf.template AsType<FloatC>()(c_offset_1),
c_buf.template AsType<FloatC>()(c_offset_2),
c_buf.template AsType<FloatC>()(c_offset_3));
}
});
});
}
#endif
constexpr index_t a_offset =
ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
constexpr index_t b_offset =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n));
constexpr index_t c_offset =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n));
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
Run_amd_asm(a_buf, AOriginIdx{}, b_buf, BOriginIdx{}, c_buf, COriginIdx{});
amd_assembly_inner_product(a_buf.template AsType<FloatA>()[Number<a_offset>{}],
b_buf.template AsType<FloatB>()[Number<b_offset>{}],
c_buf.template AsType<FloatC>()(Number<c_offset>{}));
#else
Run_source(a_buf, AOriginIdx{}, b_buf, BOriginIdx{}, c_buf, COriginIdx{});
c_buf.template AsType<FloatC>()(Number<c_offset>{}) +=
inner_product_with_conversion<FloatC>{}(
a_buf.template AsType<FloatA>()[Number<a_offset>{}],
b_buf.template AsType<FloatB>()[Number<b_offset>{}]);
#endif
});
});
});
}
};
......
@@ -5,6 +5,24 @@
namespace ck {
// c += inner_product(a, b)
__device__ void amd_assembly_inner_product(const float& a, const float& b, float& c)
{
#if CK_USE_AMD_V_FMAC_F32
asm volatile("\n \
v_fmac_f32 %0, %1, %2 \n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#else
asm volatile("\n \
v_mac_f32 %0, %1, %2 \n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#endif
}
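
A minimal usage sketch of the new helper above; dot3 and the fixed length 3 are invented for illustration, while the helper itself is the c += a * b accumulation defined in this hunk:

```cpp
// Hypothetical usage of amd_assembly_inner_product; dot3 and the length 3 are
// invented for illustration.
__device__ float dot3(const float a[3], const float b[3])
{
    float c = 0.0f;
    for(int i = 0; i < 3; ++i)
    {
        // c += a[i] * b[i], emitted as a single v_fmac_f32 / v_mac_f32
        amd_assembly_inner_product(a[i], b[i], c);
    }
    return c;
}
```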
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
......