"csrc/git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "547759a6b996884d48b9aa4d5c680a6e7e284c20"
Commit 8c2244af authored by aska-0096
Browse files

temp save

parent 737f5f25
...@@ -66,6 +66,7 @@ struct BlockwiseGemmXdlops_pipeline_v1 ...@@ -66,6 +66,7 @@ struct BlockwiseGemmXdlops_pipeline_v1
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{}; static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{};
static constexpr auto xdlops_gemm_sp = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack/2, TransposeC>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack; static constexpr index_t KRepeat = KPerThread / KPack;
...@@ -343,7 +344,9 @@ struct BlockwiseGemmXdlops_pipeline_v1 ...@@ -343,7 +344,9 @@ struct BlockwiseGemmXdlops_pipeline_v1
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// __builtin_amdgcn_sched_barrier(0);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
...@@ -432,11 +435,17 @@ struct BlockwiseGemmXdlops_pipeline_v1 ...@@ -432,11 +435,17 @@ struct BlockwiseGemmXdlops_pipeline_v1
}); });
}); });
}); });
__builtin_amdgcn_sched_group_barrier(0x020, 2, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 4, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(0x020, 2, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x100, 3, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 6, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 3, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 6, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
// Wait all wave consume this k-loop data // Wait all wave consume this k-loop data
// __syncthreads();
// __builtin_amdgcn_s_waitcnt(0);
// __builtin_amdgcn_s_barrier();
block_sync_lds(); block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
...@@ -453,7 +462,7 @@ struct BlockwiseGemmXdlops_pipeline_v1 ...@@ -453,7 +462,7 @@ struct BlockwiseGemmXdlops_pipeline_v1
vector_type<FloatAB, KPack> a_thread_vec; vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec; vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack/2, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) = a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, (KRepeat - 1) % 2, 0, ik))>{}]; make_tuple(m0, (KRepeat - 1) % 2, 0, ik))>{}];
...@@ -468,16 +477,16 @@ struct BlockwiseGemmXdlops_pipeline_v1 ...@@ -468,16 +477,16 @@ struct BlockwiseGemmXdlops_pipeline_v1
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run( xdlops_gemm_sp.template Run(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
}); });
}); });
// Wait all wave produce next k-loop data // Wait all wave produce next k-loop data
// __syncthreads(); __builtin_amdgcn_sched_group_barrier(0x200, 6, 0); // DS write
// __builtin_amdgcn_s_waitcnt(0); __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
// __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0);
block_sync_lds(); block_sync_lds();
// Here 1 time prefetch read(idx=0) of next K-loop // Here 1 time prefetch read(idx=0) of next K-loop
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
...@@ -499,12 +508,43 @@ struct BlockwiseGemmXdlops_pipeline_v1 ...@@ -499,12 +508,43 @@ struct BlockwiseGemmXdlops_pipeline_v1
b_thread_buf); b_thread_buf);
}); });
}); });
__builtin_amdgcn_sched_group_barrier(0x020, 6, 0); // VMEM read static_for<0, MRepeat, 1>{}([&](auto m0) {
__builtin_amdgcn_sched_group_barrier(0x100, 6, 0); // DS read static_for<0, NRepeat, 1>{}([&](auto n0) {
__builtin_amdgcn_sched_group_barrier(0x008, 16, 0); // MFMA /* Compute N */
__builtin_amdgcn_sched_group_barrier(0x200, 6, 0); // DS write vector_type<FloatAB, KPack> a_thread_vec;
__builtin_amdgcn_sched_group_barrier(0x008, 16, 0); // MFMA vector_type<FloatAB, KPack> b_thread_vec;
__builtin_amdgcn_sched_group_barrier(0x100, 6, 0); // DS read
static_for<0, KPack/2, 1>{}([&](auto ik) {
a_thread_vec.template AsType<FloatAB>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, (KRepeat - 1) % 2, 0, ik+KPack/2))>{}];
b_thread_vec.template AsType<FloatAB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, (KRepeat - 1) % 2, 0, ik+KPack/2))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm_sp.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
// Current best = 109T at float initialization 3840x4096x4096
// __builtin_amdgcn_sched_group_barrier(0x020, 4, 0); // VMEM read
// __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(0x020, 2, 0);
// __builtin_amdgcn_sched_group_barrier(0x100, 6, 0); // DS read
// __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
// __builtin_amdgcn_sched_group_barrier(0x200, 6, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 6, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
} while(i < (num_loop - 1)); } while(i < (num_loop - 1));
} }
...@@ -583,6 +623,9 @@ struct BlockwiseGemmXdlops_pipeline_v1 ...@@ -583,6 +623,9 @@ struct BlockwiseGemmXdlops_pipeline_v1
}); });
}); });
}); });
__builtin_amdgcn_sched_group_barrier(0x100, 6, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 16, 0); // MFMA
__builtin_amdgcn_sched_barrier(0);
/* Final Compute issue */ /* Final Compute issue */
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment