Commit b6e43b25 authored by Chao Liu's avatar Chao Liu
Browse files

bug fix

parent f5654649
......@@ -225,15 +225,21 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
static_for<0, K, 1>{}([&](auto k) {
static_for<0, M, 1>{}([&](auto m) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(k, m));
constexpr index_t a_offset =
ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
#if 0
if constexpr(N == 2)
{
constexpr index_t b_offset_0 = BDesc{}.CalculateOffset(make_tuple(k, I0));
constexpr index_t b_offset_1 = BDesc{}.CalculateOffset(make_tuple(k, I1));
constexpr index_t b_offset_0 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
constexpr index_t b_offset_1 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
constexpr index_t c_offset_0 = CDesc{}.CalculateOffset(make_tuple(m, I0));
constexpr index_t c_offset_1 = CDesc{}.CalculateOffset(make_tuple(m, I1));
constexpr index_t c_offset_0 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
constexpr index_t c_offset_1 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
amd_assembly_outer_product_1x2(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset_0>{}],
......@@ -243,15 +249,23 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
}
else if constexpr(N == 4)
{
constexpr index_t b_offset_0 = BDesc{}.CalculateOffset(make_tuple(k, I0));
constexpr index_t b_offset_1 = BDesc{}.CalculateOffset(make_tuple(k, I1));
constexpr index_t b_offset_2 = BDesc{}.CalculateOffset(make_tuple(k, I2));
constexpr index_t b_offset_3 = BDesc{}.CalculateOffset(make_tuple(k, I3));
constexpr index_t c_offset_0 = CDesc{}.CalculateOffset(make_tuple(m, I0));
constexpr index_t c_offset_1 = CDesc{}.CalculateOffset(make_tuple(m, I1));
constexpr index_t c_offset_2 = CDesc{}.CalculateOffset(make_tuple(m, I2));
constexpr index_t c_offset_3 = CDesc{}.CalculateOffset(make_tuple(m, I3));
constexpr index_t b_offset_0 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
constexpr index_t b_offset_1 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
constexpr index_t b_offset_2 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I2));
constexpr index_t b_offset_3 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I3));
constexpr index_t c_offset_0 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
constexpr index_t c_offset_1 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
constexpr index_t c_offset_2 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I2));
constexpr index_t c_offset_3 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I3));
amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset_0>{}],
......@@ -264,24 +278,18 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
c_buf(Number<c_offset_3>{}));
}
else
#endif
{
static_for<0, N, 1>{}([&](auto n) {
constexpr index_t a_offset =
ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
constexpr index_t b_offset =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n));
constexpr index_t c_offset =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n));
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
amd_assembly_inner_product(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset>{}],
c_buf(Number<c_offset>{}));
#else
c_buf(Number<c_offset>{}) += inner_product_with_conversion<FloatC>{}(
a_buf[Number<a_offset>{}], b_buf[Number<b_offset>{}]);
#endif
});
}
});
......
......@@ -28,7 +28,7 @@
#endif
// launch bounds
#define CK_USE_LAUNCH_BOUNDS 1
#define CK_USE_LAUNCH_BOUNDS 0
#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment