Commit 49316982 authored by ThomasNing's avatar ThomasNing
Browse files

get tback the restrict variable name, need to switch out to solve the transpose issue

parent 66a183d7
...@@ -45,7 +45,7 @@ struct GemmPipelineAgBgCrImplBase ...@@ -45,7 +45,7 @@ struct GemmPipelineAgBgCrImplBase
CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
{ {
// A tile in LDS // A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem); ADataType* __restrict__ p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>(); constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc); auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
......
...@@ -109,18 +109,20 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -109,18 +109,20 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num; constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num;
constexpr auto num_issue = num_buffer_load_inst;
static_for<0, num_buffer_load_inst, 1>{}([&](auto i) { static_for<0, num_buffer_load_inst, 1>{}([&](auto i) {
ignore = i; ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier( __builtin_amdgcn_sched_group_barrier(
0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read : 2 0x100, num_ds_read_inst / num_issue, 0); // DS read : 2
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
__builtin_amdgcn_sched_group_barrier( __builtin_amdgcn_sched_group_barrier(
0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write : 1 0x200, num_ds_write_inst / num_issue, 0); // DS write : 1
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
__builtin_amdgcn_sched_group_barrier( __builtin_amdgcn_sched_group_barrier(
0x008, C_MFMA_Inst_Num / num_buffer_load_inst - 3, 0); // MFMA : 5 0x008, C_MFMA_Inst_Num / num_issue - 3, 0); // MFMA : 5
}); });
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
...@@ -136,8 +138,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -136,8 +138,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
const BDramBlockWindowTmp& b_dram_block_window_tmp, const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func, const BElementFunction& b_element_func,
index_t num_loop, index_t num_loop,
void* p_smem_0, void* __restrict__ p_smem_0,
void* p_smem_1) const void* __restrict__ p_smem_1) const
{ {
static_assert( static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> && std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
...@@ -266,6 +268,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -266,6 +268,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window); Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window); Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
printf("Tail Num: =====================================\n");
printf("%d \n", static_cast<int>(TailNum));
if(HasHotLoop) if(HasHotLoop)
{ {
// minus 2 because we have ping-pong double buffer. // minus 2 because we have ping-pong double buffer.
...@@ -392,8 +397,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -392,8 +397,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp, const BDramBlockWindowTmp& b_dram_block_window_tmp,
const index_t num_loop, const index_t num_loop,
void* p_smem_0, void* __restrict__ p_smem_0,
void* p_smem_1) const void* __restrict__ p_smem_1) const
{ {
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>( return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp, a_dram_block_window_tmp,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment