Commit d51da77a authored by carlushuang

Now the result is correct and everything works (but still has scratch).

parent f58137f2
@@ -28,10 +28,10 @@ using DeviceGemmStreamK = ck::tensor_operation::device::DeviceGemmXdlStreamK
 // ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+   < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
 // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8>;
-   < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>;
+// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>;
 // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 128, 4, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8>;
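Note on this hunk: the active tuning instance switches from the 32x64-per-block, BlockSize-128 configuration to the 128x128-per-block, BlockSize-256 configuration, and its B-side block-transfer settings change (ThreadClusterArrangeOrder and SrcAccessOrder go from S<1, 0, 2> to S<0, 2, 1>, SrcVectorDim from 2 to 1, SrcScalarPerVector from 8 to 2), matching the B-side settings of the other instances in the list. Below is a minimal, self-contained sketch of why SrcScalarPerVector drops to 2, assuming the usual CK rule that a thread's vector width along the source vector dimension cannot exceed its per-thread slice length along that dimension (the names are taken from the column-header comment above, not from the literal template parameters):

#include <cstdio>

int main()
{
    // B block tile and B thread-cluster shape of the newly enabled instance:
    // K0PerBlock x NPerBlock x K1 covered by cluster lengths S<4, 64, 1>.
    constexpr int K0PerBlock = 4, NPerBlock = 128, K1 = 8;
    constexpr int cluster_k0 = 4, cluster_n = 64, cluster_k1 = 1; // 4*64*1 = 256 = BlockSize

    // Per-thread slice of the B tile along each dimension.
    constexpr int slice_k0 = K0PerBlock / cluster_k0; // 1
    constexpr int slice_n  = NPerBlock / cluster_n;   // 2
    constexpr int slice_k1 = K1 / cluster_k1;         // 8

    // With BBlockTransferSrcVectorDim = 1 (the N dimension), the widest legal
    // vector load per thread is the N slice width, hence SrcScalarPerVector = 2.
    std::printf("per-thread B slice: k0=%d n=%d k1=%d -> max N vector width = %d\n",
                slice_k0, slice_n, slice_k1, slice_n);
}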
@@ -557,10 +557,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
             math::integer_divide_ceil(Number<MPerBlock>{}, cluster_length_reduce.At(I0));
         constexpr auto NReduceIters = math::integer_divide_ceil(
             Number<NPerBlock>{},
-            cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
-        // constexpr auto acc_thread_buf_desc = make_naive_tensor_descriptor_packed(
-        //     make_tuple(Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
+            cluster_length_reduce.At(I1) *
+                Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{});
         constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
             make_tuple(I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
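Note: wrapping the divisor in Number<...>{} keeps the whole product a compile-time constant, so integer_divide_ceil yields constexpr iteration counts for the reduce loops below. A worked example of the arithmetic, with illustrative values only (the real MPerBlock/NPerBlock, reduce-cluster lengths, and CBlockTransferScalarPerVector_NWaveNPerXDL come from the selected tuning instance):

#include <cstdio>

// Hypothetical stand-in for ck::math::integer_divide_ceil.
constexpr int integer_divide_ceil(int x, int y) { return (x + y - 1) / y; }

int main()
{
    constexpr int MPerBlock = 128, NPerBlock = 128; // assumed tile sizes
    constexpr int cluster_m = 32, cluster_n = 8;    // assumed reduce-cluster lengths (32*8 = 256 threads)
    constexpr int scalar_per_vector = 8;            // assumed CBlockTransferScalarPerVector_NWaveNPerXDL

    // Each reduce thread covers scalar_per_vector contiguous N elements per iteration.
    constexpr int MReduceIters = integer_divide_ceil(MPerBlock, cluster_m);                     // 4
    constexpr int NReduceIters = integer_divide_ceil(NPerBlock, cluster_n * scalar_per_vector); // 2

    std::printf("MReduceIters=%d NReduceIters=%d\n", MReduceIters, NReduceIters);
}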
@@ -588,7 +586,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
             make_multi_index(0,
                              0,
                              0,
-                             -1 * (MReduceIters - 1) * cluster_length_reduce.At(I1) *
+                             -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
                                  CBlockTransferScalarPerVector_NWaveNPerXDL);
         constexpr auto partial_acc_store_step_m =
             make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
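Note: this reverse step rewinds the N window after the last N sub-tile of a row, so it must be scaled by (NReduceIters - 1), not (MReduceIters - 1); .value additionally forces plain integer arithmetic on the compile-time cluster length. A small self-contained sketch of the window walk, using the same assumed sizes as above, showing that the corrected reverse step returns the N offset to 0 before the window moves down one M row:

#include <cassert>
#include <cstdio>

int main()
{
    constexpr int cluster_m = 32, cluster_n = 8, scalar_per_vector = 8; // assumed values
    constexpr int MReduceIters = 4, NReduceIters = 2;                   // assumed iteration counts

    const int step_n         = cluster_n * scalar_per_vector;                           // forward move along N
    const int step_n_reverse = -1 * cluster_n * (NReduceIters - 1) * scalar_per_vector; // rewind along N
    const int step_m         = cluster_m;                                               // move down along M

    int m = 0, n = 0;
    for(int i_m = 0; i_m < MReduceIters; ++i_m)
    {
        for(int i_n = 0; i_n < NReduceIters; ++i_n)
        {
            std::printf("sub-tile window at (m=%d, n=%d)\n", m, n);
            n += (i_n != NReduceIters - 1) ? step_n : step_n_reverse;
        }
        assert(n == 0); // the reverse step must bring N back to the start of the row
        if(i_m != MReduceIters - 1)
            m += step_m;
    }
}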
@@ -603,7 +601,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                     CBlockTransferScalarPerVector_NWaveNPerXDL,
                     true>
             acc_buf;
-        acc_buf.Clear();
         // start to compute
         auto reduction_idx = blockIdx.x - block_mapping.reduction_start_block_idx;
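Note: acc_buf.Clear() is removed here and re-added inside the per-sub-tile reduction loop (the hunk starting at line 653 below). Clearing only once before the loops leaves the previous sub-tile's partial sums in the accumulator when it is reused for the next (i_m_reduce, i_n_reduce) sub-tile, which is presumably the wrong-result bug this commit fixes. A toy illustration of why the clear must happen per sub-tile (plain C++, not CK code):

#include <array>
#include <cstdio>

int main()
{
    // One accumulator ("acc") reused across several output sub-tiles,
    // each reduced over several partial results.
    constexpr int num_sub_tiles = 3, num_partials = 4;
    std::array<std::array<float, num_sub_tiles>, num_partials> partial{};
    for(int p = 0; p < num_partials; ++p)
        for(int t = 0; t < num_sub_tiles; ++t)
            partial[p][t] = 1.0f; // each partial contributes 1.0 to every sub-tile

    float acc = 0.0f; // clearing only here (as before this commit) would leak sums across sub-tiles
    for(int t = 0; t < num_sub_tiles; ++t)
    {
        acc = 0.0f; // per-sub-tile clear, mirroring the acc_buf.Clear() moved into the loop
        for(int p = 0; p < num_partials; ++p)
            acc += partial[p][t];
        std::printf("sub-tile %d: %.1f (expected %.1f)\n", t, acc, float(num_partials));
    }
}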
@@ -656,11 +653,24 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         // block synchronization
         wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
+        // if(threadIdx.x == 0) {
+        //     if(reduction_idx == 0){
+        //         printf("(miter:%d, niter:%d, cluster red:%d,%d)bid:%d, rid:%d, os:%d-%d(%d),
+        //         spatial:%d-%d, tid:%d, %d, %d\n",
+        //             MReduceIters(), NReduceIters(), cluster_length_reduce.At(I0).value,
+        //             cluster_length_reduce.At(I1).value, static_cast<int>(blockIdx.x),
+        //             reduction_idx, tile_acc_offset_start, tile_acc_offset_end,
+        //             tile_acc_offset_end - tile_acc_offset_start, spatial_idx[I0],
+        //             spatial_idx[I1], static_cast<int>(threadIdx.x), thread_m_cluster_id,
+        //             thread_n_cluster_id);
+        //     }
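Note: wait_eq blocks the reducing workgroup until the per-tile counter reaches tile_acc_offset_end - tile_acc_offset_start, i.e. until every stream-K block contributing partial results to this output tile has published them. A generic sketch of what such a wait typically looks like, assuming a plain spin on a counter in global memory (illustrative only, not the actual ck::workgroup_barrier implementation, whose API and memory ordering may differ):

#include <hip/hip_runtime.h>
#include <cstdint>

// Illustrative device-side spin wait on a per-tile counter in global memory.
__device__ void wait_eq_sketch(const volatile uint32_t* p_counter, uint32_t expected)
{
    if(threadIdx.x == 0)
    {
        // Spin until all producer blocks have incremented the counter for this tile.
        while(*p_counter != expected) {}
    }
    // Make sure every thread in the workgroup proceeds only after the wait completes.
    __syncthreads();
}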
         using Accumulation = ck::detail::
             AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, FloatAcc>;
         static_for<0, MReduceIters, 1>{}([&](auto i_m_reduce) {
             static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
+                acc_buf.Clear();
                 for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
                 {
                     auto c_partial_acc_buf =
@@ -691,22 +701,24 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                         acc_buf,
                         c_grid_desc_mblock_mperblock_nblock_nperblock,
                         c_grid_buf);
-                if constexpr(i_n_reduce != (NReduceIters - 1))
-                {
-                    acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
-                                                partial_acc_load_step_n);
-                    acc_store.MoveDstSliceWindow(
-                        c_grid_desc_mblock_mperblock_nblock_nperblock,
-                        partial_acc_store_step_n);
-                }
-                else
-                {
-                    acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
-                                                partial_acc_load_step_n_reverse);
-                    acc_store.MoveDstSliceWindow(
-                        c_grid_desc_mblock_mperblock_nblock_nperblock,
-                        partial_acc_store_step_n_reverse);
-                }
+                if constexpr(NReduceIters != 1)
+                {
+                    if constexpr(i_n_reduce != (NReduceIters - 1))
+                    {
+                        acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
+                                                    partial_acc_load_step_n);
+                        acc_store.MoveDstSliceWindow(
+                            c_grid_desc_mblock_mperblock_nblock_nperblock,
+                            partial_acc_store_step_n);
+                    }
+                    else
+                    {
+                        acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
+                                                    partial_acc_load_step_n_reverse);
+                        acc_store.MoveDstSliceWindow(
+                            c_grid_desc_mblock_mperblock_nblock_nperblock,
+                            partial_acc_store_step_n_reverse);
+                    }
+                }
             });
             if constexpr(i_m_reduce != MReduceIters - 1)
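Note: the new if constexpr(NReduceIters != 1) guard skips all N-direction window movement when there is only a single N sub-tile; previously, in that case, i_n_reduce == NReduceIters - 1 on every iteration, so the else branch still applied the reverse move. A tiny sketch contrasting the two cases under assumed iteration counts (plain C++, illustrative only):

#include <cstdio>
#include <initializer_list>

int main()
{
    constexpr int MReduceIters = 4; // assumed
    // Try both the multi-sub-tile and the single-sub-tile case along N.
    for(int NReduceIters : {2, 1})
    {
        int n_offset = 0, moves = 0; // n_offset counted in sub-tile-width steps
        for(int i_m = 0; i_m < MReduceIters; ++i_m)
            for(int i_n = 0; i_n < NReduceIters; ++i_n)
            {
                // ... accumulate partials and store this sub-tile ...
                if(NReduceIters != 1) // added guard: one N sub-tile => no N window moves needed
                {
                    n_offset += (i_n != NReduceIters - 1) ? +1 : -(NReduceIters - 1);
                    ++moves;
                }
            }
        std::printf("NReduceIters=%d: N window moves issued = %d, final N offset = %d\n",
                    NReduceIters, moves, n_offset);
    }
}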