Commit 3cf22191 authored by Jing Zhang

clean

parent 5d76eccd
@@ -159,7 +159,8 @@ struct ThreadGroupTensorSliceTransfer_v7r2
     __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step)
     {
-        MoveSrcSliceWindow(src_descs, Number<0>{}, step);
+        static_for<0, SrcDescs::Size(), 1>{}(
+            [&](auto i) { MoveSrcSliceWindow(src_descs, i, step); });
     }
 
     template <index_t IDst>
@@ -175,7 +176,8 @@ struct ThreadGroupTensorSliceTransfer_v7r2
     __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step)
     {
-        MoveDstSliceWindow(dst_descs, Number<0>{}, step);
+        static_for<0, DstDescs::Size(), 1>{}(
+            [&](auto i) { MoveDstSliceWindow(dst_descs, i, step); });
     }
 
     private:
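The change above makes the group-level MoveSrcSliceWindow/MoveDstSliceWindow walk every descriptor in the tuple instead of forwarding only to descriptor 0. Below is a minimal host-side sketch of that compile-time-loop-over-a-tuple pattern; std::index_sequence stands in for ck::static_for, and Step, DescA, DescB, static_for_impl are made-up names for illustration, not CK types.

#include <cstddef>
#include <iostream>
#include <tuple>
#include <type_traits>
#include <utility>

// expand f once per index, analogous to static_for<0, N, 1>{}(f)
template <typename F, std::size_t... Is>
void static_for_impl(F&& f, std::index_sequence<Is...>)
{
    (f(std::integral_constant<std::size_t, Is>{}), ...);
}

template <std::size_t N, typename F>
void static_for(F&& f)
{
    static_for_impl(std::forward<F>(f), std::make_index_sequence<N>{});
}

struct Step { int dx; };          // stand-in for the Index step
struct DescA { int origin = 0; }; // stand-in descriptors
struct DescB { int origin = 100; };

// advance every descriptor's window origin, one tuple element per iteration
template <typename... Descs>
void MoveSrcSliceWindow(std::tuple<Descs...>& descs, const Step& step)
{
    static_for<sizeof...(Descs)>([&](auto i) { std::get<i.value>(descs).origin += step.dx; });
}

int main()
{
    std::tuple<DescA, DescB> descs;
    MoveSrcSliceWindow(descs, Step{8});
    std::cout << std::get<0>(descs).origin << ' ' << std::get<1>(descs).origin << '\n'; // 8 108
}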
@@ -705,13 +705,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
         a_blockwise_copy.RunRead(a_grid_desc, a_grid_bufs);
         b_blockwise_copy.RunRead(b_grid_desc, b_grid_bufs);
 
-        static_for<0, NumATensor, 1>{}([&](auto i) {
-            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, i, a_block_copy_step);
-        });
-        static_for<0, NumBTensor, 1>{}([&](auto i) {
-            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, i, b_block_copy_step);
-        });
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
         // Initialize C
         c_thread_buf.Clear();
@@ -738,13 +733,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
             block_sync_lds();
 
-            static_for<0, NumATensor, 1>{}([&](auto i) {
-                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, i, a_block_copy_step);
-            });
-            static_for<0, NumBTensor, 1>{}([&](auto i) {
-                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, i, b_block_copy_step);
-            });
+            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
             a_blockwise_copy.RunWrite(tie(a_block_desc), tie(a_block_buf));
             b_blockwise_copy.RunWrite(tie(b_block_desc), tie(b_block_buf));
@@ -760,22 +750,21 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
                 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
             }
         }
-#endif
-#if 0
+#else
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
             GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
 
-        gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_descs,
+        gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(as_grid_desc_ak0_m_ak1,
                                                                a_block_desc_ak0_m_ak1,
                                                                a_blockwise_copy,
-                                                               a_grid_bufs,
+                                                               as_grid_buf,
                                                                a_block_buf,
                                                                a_block_slice_copy_step,
-                                                               b_grid_descs,
+                                                               bs_grid_desc_bk0_n_bk1,
                                                                b_block_desc_bk0_n_bk1,
                                                                b_blockwise_copy,
-                                                               b_grid_bufs,
+                                                               bs_grid_buf,
                                                                b_block_buf,
                                                                b_block_slice_copy_step,
                                                                blockwise_gemm,
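With the per-tensor loops folded into the transfer objects, the gridwise code advances each slice window with a single call and, through the new #else branch, hands the main K loop to the pipeline chosen by GridwiseGemmPipeline_Selector. The following is only a rough host-side analogue of that prefetch loop (RunRead, MoveSrcSliceWindow, RunWrite, blockwise GEMM); the sizes are made up and plain std::vector buffers stand in for global memory, registers, and LDS.

#include <cstdio>
#include <vector>

int main()
{
    const int KBlocks = 4, KPerBlock = 2;
    std::vector<float> a_global(KBlocks * KPerBlock, 1.0f);
    std::vector<float> b_global(KBlocks * KPerBlock, 2.0f);

    std::vector<float> a_lds(KPerBlock), b_lds(KPerBlock); // stand-ins for the LDS tiles
    int a_window = 0, b_window = 0;                        // slice-window origins
    float c_acc = 0.0f;                                    // c_thread_buf analogue

    // prologue: RunRead the first slice, then MoveSrcSliceWindow
    std::vector<float> a_reg(a_global.begin(), a_global.begin() + KPerBlock);
    std::vector<float> b_reg(b_global.begin(), b_global.begin() + KPerBlock);
    a_window += KPerBlock;
    b_window += KPerBlock;

    for(int k = 0; k < KBlocks; ++k)
    {
        // RunWrite: registers -> "LDS" (block_sync_lds has no host analogue)
        a_lds = a_reg;
        b_lds = b_reg;

        // prefetch the next K slice while the current one is consumed
        if(k + 1 < KBlocks)
        {
            a_reg.assign(a_global.begin() + a_window, a_global.begin() + a_window + KPerBlock);
            b_reg.assign(b_global.begin() + b_window, b_global.begin() + b_window + KPerBlock);
            a_window += KPerBlock;
            b_window += KPerBlock;
        }

        // blockwise_gemm.Run analogue: accumulate over the current K slice
        for(int i = 0; i < KPerBlock; ++i)
            c_acc += a_lds[i] * b_lds[i];
    }

    std::printf("c = %g\n", c_acc); // 8 elements of 1*2 -> 16
}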
@@ -167,17 +167,31 @@ struct ThreadwiseTensorSliceTransfer_v7r2
                     is_src_valid);
             });
 
-            if constexpr(!has_vec_len<decltype(element_op_)>::value)
+            if constexpr(has_vec_len<decltype(element_op_)>::value)
             {
+                constexpr auto elem_op_vec_len = decltype(element_op_)::vec_len;
+
+                static_assert(is_same<remove_cvref_t<decltype(elem_op_vec_len)>, index_t>::value,
+                              "vec_len in element_op_ type is not index_t");
+
+                static_assert(elem_op_vec_len == 2 || elem_op_vec_len == 4 || elem_op_vec_len == 8,
+                              "vec_len in element_op_ must be 2, 4, 8");
+
+                static_assert(SrcScalarPerVector % elem_op_vec_len == 0,
+                              "vec_len in element_op_ cannot be divided by SrcScalarPerVector!");
+
                 // apply pointwise function
-                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) {
                     // get reference to src data
                     const auto src_data_refs = generate_tie(
                         // return type should be lvalue
                         [&](auto iSrc) -> const auto& {
                             using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
 
-                            return src_vectors[iSrc].template AsType<SrcData>()[i];
+                            using elem_op_vec_t =
+                                typename vector_type<SrcData, elem_op_vec_len>::type;
+
+                            return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
                         },
                         Number<nSrc>{});
@@ -187,7 +201,10 @@ struct ThreadwiseTensorSliceTransfer_v7r2
                         [&](auto iDst) -> auto& {
                             using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
 
-                            return dst_vectors(iDst).template AsType<DstData>()(i);
+                            using elem_op_vec_t =
+                                typename vector_type<DstData, elem_op_vec_len>::type;
+
+                            return dst_vectors(iDst).template AsType<elem_op_vec_t>()(i);
                         },
                         Number<nDst>{});
@@ -204,29 +221,15 @@ struct ThreadwiseTensorSliceTransfer_v7r2
             }
             else
             {
-                constexpr auto elem_op_vec_len = decltype(element_op_)::vec_len;
-
-                static_assert(is_same<remove_cvref_t<decltype(elem_op_vec_len)>, index_t>::value,
-                              "vec_len in element_op_ type is not index_t");
-
-                static_assert(elem_op_vec_len == 2 || elem_op_vec_len == 4 || elem_op_vec_len == 8,
-                              "vec_len in element_op_ must be 2, 4, 8");
-
-                static_assert(SrcScalarPerVector % elem_op_vec_len == 0,
-                              "vec_len in element_op_ cannot be divided by SrcScalarPerVector!");
-
                 // apply pointwise function
-                static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) {
+                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                     // get reference to src data
                     const auto src_data_refs = generate_tie(
                         // return type should be lvalue
                         [&](auto iSrc) -> const auto& {
                             using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
 
-                            using elem_op_vec_t =
-                                typename vector_type<SrcData, elem_op_vec_len>::type;
-
-                            return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
+                            return src_vectors[iSrc].template AsType<SrcData>()[i];
                         },
                         Number<nSrc>{});
@@ -236,10 +239,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
                         [&](auto iDst) -> auto& {
                             using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
 
-                            using elem_op_vec_t =
-                                typename vector_type<DstData, elem_op_vec_len>::type;
-
-                            return dst_vectors(iDst).template AsType<elem_op_vec_t>()(i);
+                            return dst_vectors(iDst).template AsType<DstData>()(i);
                         },
                         Number<nDst>{});
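The reordered branches dispatch on whether the element-wise operator advertises a vec_len: if it does, the pointwise function is applied to vec_len-wide chunks (SrcScalarPerVector / vec_len iterations); otherwise it is applied scalar by scalar. Below is a small self-contained sketch of that dispatch; the has_vec_len trait shape and the ScalarScale/VectorScale operators are assumptions for illustration and differ from CK's actual has_vec_len / vector_type machinery.

#include <array>
#include <cstdio>
#include <type_traits>

// detect a static vec_len member on the element-wise op (assumed trait shape)
template <typename T, typename = void>
struct has_vec_len : std::false_type {};

template <typename T>
struct has_vec_len<T, std::void_t<decltype(T::vec_len)>> : std::true_type {};

struct ScalarScale { float s; float operator()(float x) const { return s * x; } };

struct VectorScale
{
    static constexpr int vec_len = 4; // op promises to consume 4 lanes per call
    float s;
    void operator()(float* dst, const float* src) const
    {
        for(int l = 0; l < vec_len; ++l) dst[l] = s * src[l];
    }
};

template <int SrcScalarPerVector, typename Op>
void apply(const Op& op, const std::array<float, SrcScalarPerVector>& src,
           std::array<float, SrcScalarPerVector>& dst)
{
    if constexpr(has_vec_len<Op>::value)
    {
        static_assert(SrcScalarPerVector % Op::vec_len == 0,
                      "SrcScalarPerVector must be divisible by the op's vec_len");
        // vectorized path: one call per vec_len-wide chunk
        for(int i = 0; i < SrcScalarPerVector / Op::vec_len; ++i)
            op(dst.data() + i * Op::vec_len, src.data() + i * Op::vec_len);
    }
    else
    {
        // scalar path: one call per element
        for(int i = 0; i < SrcScalarPerVector; ++i) dst[i] = op(src[i]);
    }
}

int main()
{
    std::array<float, 8> src{1, 2, 3, 4, 5, 6, 7, 8}, dst{};
    apply<8>(VectorScale{2.0f}, src, dst);  // vectorized path
    apply<8>(ScalarScale{0.5f}, dst, dst);  // scalar path, in place
    std::printf("%g %g\n", dst[0], dst[7]); // 1 8
}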