"...composable_kernel.git" did not exist on "b491ebf38480bc0d6cb329ba6825dee610c59097"
Commit 7cec63a6 authored by aska-0096's avatar aska-0096
Browse files

remove agpr usage when vgpr usage <256

parent 1a089f6f
...@@ -418,13 +418,65 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -418,13 +418,65 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
i += 1; i += 1;
} while(i < (num_loop - 1)); } while(i < (num_loop - 2));
} }
// tail // tail
if constexpr(TailNum == TailNumber::Full) if constexpr(TailNum == TailNumber::Full)
{ {
a_thread_buf_tail = a_thread_buf;
b_thread_buf_tail = b_thread_buf; block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k0, I0),
a_thread_buf_tail);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k0, I0),
b_thread_buf_tail);
});
});
} }
} }
...@@ -580,11 +632,66 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -580,11 +632,66 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
i += 1; i += 1;
} while(i < (num_loop - 1)); } while(i < (num_loop - 2));
} }
// tail // tail
if constexpr(TailNum == TailNumber::Full) if constexpr(TailNum == TailNumber::Full)
{ {
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k0, I0),
b_thread_buf);
});
});
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
......
...@@ -234,13 +234,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo ...@@ -234,13 +234,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
} }
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{ {
constexpr index_t instance_lds_size = return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
MPerBlock * KPerBlock * sizeof(ADataType) +
NPerBlock * KPerBlock * sizeof(BDataType);
return ((MPerBlock * NPerBlock / BlockSize <= 128) &&
(instance_lds_size <= 32768))
? 2
: 1;
} }
else else
{ {
......
...@@ -212,13 +212,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -212,13 +212,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
} }
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{ {
constexpr index_t instance_lds_size = return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
MPerBlock * KPerBlock * sizeof(ADataType) +
NPerBlock * KPerBlock * sizeof(BDataType);
return ((MPerBlock * NPerBlock / BlockSize <= 128) &&
(instance_lds_size <= 32768))
? 2
: 1;
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment