Commit 16e3f66a authored by ltqin
Browse files

fix bug

parent 263589eb
......@@ -54,7 +54,8 @@ using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
#else
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 7, 1>;
< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 1, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 7, 1>;
// < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 7, 1>;
// < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 32, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 7, 1>;
// < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 7, 1>;
using ADataType = float;
......@@ -87,10 +88,10 @@ template <typename DataType>
std::ostream& show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix)
{
os << "[" << std::endl;
for(int x = 0; x < matrix.mDesc.GetLengths()[0]; x++)
for(size_t x = 0; x < matrix.mDesc.GetLengths()[0]; x++)
{
os << "[";
for(int y = 0; y < matrix.mDesc.GetLengths()[1]; y++)
for(size_t y = 0; y < matrix.mDesc.GetLengths()[1]; y++)
{
os << std::setw(5) << static_cast<float>(matrix(x, y));
}
......@@ -117,7 +118,7 @@ int main(int argc, char* argv[])
#else
ck::index_t M = 16;
ck::index_t N = 16;
ck::index_t K = 8;
ck::index_t K = 32;
ck::index_t StrideA = 8;
ck::index_t StrideB = 8;
......
......@@ -113,7 +113,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
static constexpr auto I7 = Number<7>{};
static constexpr auto BaseMultK0 = 4;
static constexpr auto MultiK0 = 4 * 2;
static constexpr auto MultiK0 = BaseMultK0 * 2;
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
......@@ -192,7 +192,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
// 2-stage prefetch currently only supports an even number of K0 loops
// TODO: add support for an odd number of K0 loops
if(!((K0 / K0PerBlock) % 2 == 0))
if(!((K0 / K0PerBlock) % MultiK0 == 0))
{
return false;
}
......@@ -573,6 +573,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_thread_slice_copy_step);
blockwise_gemm.Run(a_block_buf, b_thread_1st_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
// 2nd
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
......@@ -583,8 +584,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_2nd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
// 3rd
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
......@@ -595,8 +596,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_3rd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
// 4th
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
......@@ -607,8 +608,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_4th_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
});
block_sync_lds();
......@@ -637,6 +638,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_thread_slice_copy_step);
blockwise_gemm.Run(a_block_buf, b_thread_1st_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
// 2nd
b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
......@@ -647,8 +649,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_2nd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
// 3rd
if constexpr(i < MultiK0 - BaseMultK0)
......@@ -662,8 +664,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
b_thread_slice_copy_step);
}
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_3rd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
// 4th
if constexpr(i < MultiK0 - BaseMultK0)
......@@ -677,8 +680,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
b_thread_slice_copy_step);
}
blockwise_gemm.MoveABlockSliceWindow();
blockwise_gemm.Run(a_block_buf, b_thread_4th_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow();
});
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment