Commit 00825b10 authored by aska-0096

112T

parent 8c2244af
@@ -334,6 +334,7 @@ struct BlockwiseGemmXdlops_pipeline_v1
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
     {
+        __builtin_amdgcn_sched_barrier(0);
         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
             a_thread_desc_.GetElementSpaceSize());
         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
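Note: __builtin_amdgcn_sched_barrier(0) is clang's AMDGPU scheduling fence; a mask of 0 forbids the backend scheduler from moving any instruction across the call, pinning the prefetch/compute structure the pipeline sets up below. A minimal standalone sketch (HIP/C++; the kernel and names are illustrative, not from this commit):

    #include <hip/hip_runtime.h>

    __global__ void fence_demo(const float* __restrict__ src, float* __restrict__ dst)
    {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        float v = src[i] * 2.0f;           // work before the fence
        __builtin_amdgcn_sched_barrier(0); // mask 0: nothing may be scheduled across
        dst[i] = v;                        // the store cannot be hoisted above the fence
    }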
@@ -346,7 +347,7 @@ struct BlockwiseGemmXdlops_pipeline_v1
         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
         // Initialize C
         c_thread_buf.Clear();
+        // __builtin_amdgcn_sched_barrier(0);
         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
@@ -361,16 +362,15 @@ struct BlockwiseGemmXdlops_pipeline_v1
                                    a_thread_desc_,
                                    make_tuple(m0, I0, I0, I0),
                                    a_thread_buf);
+            });
             static_for<0, NRepeat, 1>{}([&](auto n0) {
                 // read B
                 b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                    make_tuple(n0, I0, I0, I0),
                                    b_block_buf,
                                    b_thread_desc_,
                                    make_tuple(n0, I0, I0, I0),
                                    b_thread_buf);
-                });
             });
         // main body
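Note: this hunk (and the matching ones in the main body and tail below) splits the fused (m0, n0) read loop into one loop of A reads followed by one loop of B reads, so same-kind LDS reads issue back to back and can be pinned by the DS-read group barriers. A standalone C++ sketch of the shape of the change; read_a/read_b stand in for a_thread_copy_.Run / b_thread_copy_.Run, and plain loops stand in for static_for:

    #include <cstdio>

    constexpr int MRepeat = 4;
    constexpr int NRepeat = 2;

    static void read_a(int m0) { std::printf("DS read A, m0=%d\n", m0); }
    static void read_b(int n0) { std::printf("DS read B, n0=%d\n", n0); }

    int main()
    {
        // Before: B reads are nested inside the m0 loop, interleaving A and B.
        for(int m0 = 0; m0 < MRepeat; ++m0)
        {
            read_a(m0);
            for(int n0 = 0; n0 < NRepeat; ++n0)
                read_b(n0);
        }
        // After: all A reads, then all B reads; same-kind LDS reads issue
        // back to back, which the scheduler hints can then group.
        for(int m0 = 0; m0 < MRepeat; ++m0)
            read_a(m0);
        for(int n0 = 0; n0 < NRepeat; ++n0)
            read_b(n0);
        return 0;
    }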
@@ -400,16 +400,22 @@ struct BlockwiseGemmXdlops_pipeline_v1
                                        a_thread_desc_,
                                        make_tuple(m0, Number<k % 2>{}, I0, I0),
                                        a_thread_buf);
+                });
+                static_for<0, NRepeat, 1>{}([&](auto n0) {
+                    // read B
+                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
+                                       b_block_buf,
+                                       b_thread_desc_,
+                                       make_tuple(n0, Number<k % 2>{}, I0, I0),
+                                       b_thread_buf);
+                });
+                /* Compute N */
+                static_for<0, MRepeat, 1>{}([&](auto m0) {
                     static_for<0, NRepeat, 1>{}([&](auto n0) {
-                        // read B
-                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                           make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                           b_block_buf,
-                                           b_thread_desc_,
-                                           make_tuple(n0, Number<k % 2>{}, I0, I0),
-                                           b_thread_buf);
-                        /* Compute N */
                         vector_type<FloatAB, KPack> a_thread_vec;
                         vector_type<FloatAB, KPack> b_thread_vec;
@@ -435,16 +441,20 @@ struct BlockwiseGemmXdlops_pipeline_v1
                     });
                 });
             });
+            /*
             __builtin_amdgcn_sched_group_barrier(0x020, 2, 0); // VMEM read
             __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
             __builtin_amdgcn_sched_group_barrier(0x020, 4, 0); // VMEM read
-            // __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
-            // __builtin_amdgcn_sched_group_barrier(0x020, 2, 0); // VMEM read
             __builtin_amdgcn_sched_group_barrier(0x100, 3, 0); // DS read
             __builtin_amdgcn_sched_group_barrier(0x008, 6, 0); // MFMA
             __builtin_amdgcn_sched_group_barrier(0x100, 3, 0); // DS read
             __builtin_amdgcn_sched_group_barrier(0x008, 6, 0); // MFMA
+            */
+            __builtin_amdgcn_sched_group_barrier(0x020, 6, 0);  // VMEM read
+            __builtin_amdgcn_sched_group_barrier(0x100, 6, 0);  // DS read
+            __builtin_amdgcn_sched_group_barrier(0x008, 16, 0); // MFMA
             __builtin_amdgcn_sched_barrier(0);
             // Wait until all waves have consumed this k-loop's data
             block_sync_lds();
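Note: __builtin_amdgcn_sched_group_barrier(mask, size, syncid) asks the scheduler to place `size` instructions of the class selected by `mask` at this point; consecutive calls define an issue pattern. The mask bits used here follow LLVM's AMDGPUUsage encoding: 0x008 = MFMA/WMMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write. The hunk above trades the old fine-grained interleave for one coarse wave per k-iteration: 6 VMEM reads, then 6 DS reads, then 16 MFMAs. A minimal sketch of the intrinsic in isolation (HIP/C++; function name illustrative, not from this commit):

    #include <hip/hip_runtime.h>

    // Pin two global-memory loads ahead of four matrix ops at this point.
    __device__ void issue_pattern_hint()
    {
        __builtin_amdgcn_sched_group_barrier(0x020, 2, 0); // group 1: 2 VMEM reads
        __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // group 2: 4 MFMA/WMMA ops
    }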
@@ -484,8 +494,9 @@ struct BlockwiseGemmXdlops_pipeline_v1
                 });
             });
             // Wait until all waves have produced the next k-loop's data
+            __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
             __builtin_amdgcn_sched_group_barrier(0x200, 6, 0); // DS write
-            __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
+            __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
             __builtin_amdgcn_sched_barrier(0);
             block_sync_lds();
             // One prefetch read (idx=0) of the next k-loop happens here
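Note: the hunk above keeps the total of 8 MFMA hints but splits them 4 + 4 around the 6 DS writes, so the LDS stores that publish the next k-loop's tiles overlap with in-flight MFMAs instead of waiting behind all of them. The resulting issue pattern, restated with annotations:

    __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // first half of the MFMAs
    __builtin_amdgcn_sched_group_barrier(0x200, 6, 0); // 6 DS writes overlap here
    __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // remaining MFMAs drain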
@@ -498,15 +509,15 @@ struct BlockwiseGemmXdlops_pipeline_v1
                                    make_tuple(m0, I0, I0, I0),
                                    a_thread_buf);
-                static_for<0, NRepeat, 1>{}([&](auto n0) {
-                    // read B
-                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                       make_tuple(n0, I0, I0, I0),
-                                       b_block_buf,
-                                       b_thread_desc_,
-                                       make_tuple(n0, I0, I0, I0),
-                                       b_thread_buf);
-                });
+            });
+            static_for<0, NRepeat, 1>{}([&](auto n0) {
+                // read B
+                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                   make_tuple(n0, I0, I0, I0),
+                                   b_block_buf,
+                                   b_thread_desc_,
+                                   make_tuple(n0, I0, I0, I0),
+                                   b_thread_buf);
             });
             static_for<0, MRepeat, 1>{}([&](auto m0) {
                 static_for<0, NRepeat, 1>{}([&](auto n0) {
@@ -535,15 +546,15 @@ struct BlockwiseGemmXdlops_pipeline_v1
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                     });
                 });
-            // Current best = 109T at float initialization 3840x4096x4096
+            // Current best = 112T at float initialization 3840x4096x4096
             // __builtin_amdgcn_sched_group_barrier(0x020, 4, 0); // VMEM read
             // __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
-            // __builtin_amdgcn_sched_group_barrier(0x020, 2, 0);
+            // __builtin_amdgcn_sched_group_barrier(0x020, 2, 0); // VMEM read
             // __builtin_amdgcn_sched_group_barrier(0x100, 6, 0); // DS read
             // __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
             // __builtin_amdgcn_sched_group_barrier(0x200, 6, 0); // DS write
             // __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
             __builtin_amdgcn_sched_group_barrier(0x100, 6, 0); // DS read
             __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
         } while(i < (num_loop - 1));
     }
@@ -561,16 +572,20 @@ struct BlockwiseGemmXdlops_pipeline_v1
                                    a_thread_desc_,
                                    make_tuple(m0, Number<k % 2>{}, I0, I0),
                                    a_thread_buf);
+            });
+            static_for<0, NRepeat, 1>{}([&](auto n0) {
+                // read B
+                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                   make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
+                                   b_block_buf,
+                                   b_thread_desc_,
+                                   make_tuple(n0, Number<k % 2>{}, I0, I0),
+                                   b_thread_buf);
+            });
+            /* Compute N */
+            static_for<0, MRepeat, 1>{}([&](auto m0) {
                 static_for<0, NRepeat, 1>{}([&](auto n0) {
-                    // read B
-                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
-                                       b_block_buf,
-                                       b_thread_desc_,
-                                       make_tuple(n0, Number<k % 2>{}, I0, I0),
-                                       b_thread_buf);
-                    /* Compute N */
                     vector_type<FloatAB, KPack> a_thread_vec;
                     vector_type<FloatAB, KPack> b_thread_vec;
@@ -623,9 +638,7 @@ struct BlockwiseGemmXdlops_pipeline_v1
                     });
                 });
             });
-            __builtin_amdgcn_sched_group_barrier(0x100, 6, 0);  // DS read
-            __builtin_amdgcn_sched_group_barrier(0x008, 16, 0); // MFMA
-            __builtin_amdgcn_sched_barrier(0);
             /* Final Compute issue */
             static_for<0, MRepeat, 1>{}([&](auto m0) {
                 static_for<0, NRepeat, 1>{}([&](auto n0) {
...