Commit 3cf22191 authored by Jing Zhang

clean

parent 5d76eccd
@@ -159,7 +159,8 @@ struct ThreadGroupTensorSliceTransfer_v7r2
     __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step)
     {
-        MoveSrcSliceWindow(src_descs, Number<0>{}, step);
+        static_for<0, SrcDescs::Size(), 1>{}(
+            [&](auto i) { MoveSrcSliceWindow(src_descs, i, step); });
     }

     template <index_t IDst>
@@ -175,7 +176,8 @@ struct ThreadGroupTensorSliceTransfer_v7r2
     __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step)
     {
-        MoveDstSliceWindow(dst_descs, Number<0>{}, step);
+        static_for<0, DstDescs::Size(), 1>{}(
+            [&](auto i) { MoveDstSliceWindow(dst_descs, i, step); });
     }

     private:
......
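The rewritten overloads lean on CK's static_for, a compile-time counted loop: the functor is invoked once per index with a Number<i> (an integral_constant-like type), so the per-tensor MoveSrcSliceWindow(descs, i, step) overload is expanded for every tensor in the tuple at compile time. Below is a minimal standalone sketch of that mechanism (plain C++17, not CK's implementation; the free MoveSrcSliceWindow is a hypothetical stand-in for the member overload):

#include <cstdio>
#include <type_traits>
#include <utility>

template <int I>
using Number = std::integral_constant<int, I>;

// compile-time counted loop: calls f(Number<Begin>{}), f(Number<Begin+Step>{}), ...
template <int Begin, int End, int Step>
struct static_for
{
    template <typename F>
    void operator()(F&& f) const
    {
        if constexpr(Begin < End)
        {
            f(Number<Begin>{});
            static_for<Begin + Step, End, Step>{}(std::forward<F>(f));
        }
    }
};

// hypothetical stand-in for the per-tensor member overload
template <int I>
void MoveSrcSliceWindow(Number<I>, int step)
{
    std::printf("move window of source tensor %d by %d\n", I, step);
}

int main()
{
    constexpr int num_src = 3; // stands in for SrcDescs::Size()
    static_for<0, num_src, 1>{}([&](auto i) { MoveSrcSliceWindow(i, /*step=*/8); });
}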
@@ -705,13 +705,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
         a_blockwise_copy.RunRead(a_grid_desc, a_grid_bufs);
         b_blockwise_copy.RunRead(b_grid_desc, b_grid_bufs);

-        static_for<0, NumATensor, 1>{}([&](auto i) {
-            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, i, a_block_copy_step);
-        });
-
-        static_for<0, NumBTensor, 1>{}([&](auto i) {
-            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, i, b_block_copy_step);
-        });
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

         // Initialize C
         c_thread_buf.Clear();
@@ -738,13 +733,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
                 block_sync_lds();

-                static_for<0, NumATensor, 1>{}([&](auto i) {
-                    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, i, a_block_copy_step);
-                });
-
-                static_for<0, NumBTensor, 1>{}([&](auto i) {
-                    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, i, b_block_copy_step);
-                });
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                 a_blockwise_copy.RunWrite(tie(a_block_desc), tie(a_block_buf));
                 b_blockwise_copy.RunWrite(tie(b_block_desc), tie(b_block_buf));
@@ -760,22 +750,21 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
                 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
             }
         }
-#endif
-#if 0
+#else
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
             GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();

-        gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_descs,
+        gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(as_grid_desc_ak0_m_ak1,
                                                                a_block_desc_ak0_m_ak1,
                                                                a_blockwise_copy,
-                                                               a_grid_bufs,
+                                                               as_grid_buf,
                                                                a_block_buf,
                                                                a_block_slice_copy_step,
-                                                               b_grid_descs,
+                                                               bs_grid_desc_bk0_n_bk1,
                                                                b_block_desc_bk0_n_bk1,
                                                                b_blockwise_copy,
-                                                               b_grid_bufs,
+                                                               bs_grid_buf,
                                                                b_block_buf,
                                                                b_block_slice_copy_step,
                                                                blockwise_gemm,
......
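With the per-tensor loops folded into the transfer classes, the main K-block loop advances each operand with a single call, no matter how many A or B tensors feed the GEMM. A host-side toy of that call shape (MultiSourceCopy is a hypothetical stub, not the CK pipeline):

#include <cstdio>

// hypothetical stand-in for ThreadGroupTensorSliceTransfer_v7r2: one object
// owns the slice windows of all source tensors of an operand
struct MultiSourceCopy
{
    int num_tensors;

    // moves every source window by the same step, as the new overload
    // does internally via static_for
    void MoveSrcSliceWindow(int step) const
    {
        for(int i = 0; i < num_tensors; ++i)
            std::printf("  tensor %d: window += %d\n", i, step);
    }
};

int main()
{
    const MultiSourceCopy a_copy{/*NumATensor=*/2};
    const MultiSourceCopy b_copy{/*NumBTensor=*/1};

    for(int k = 0; k < 3; ++k) // stands in for the main K-block loop
    {
        std::printf("K step %d\n", k);
        a_copy.MoveSrcSliceWindow(/*a_block_copy_step=*/8); // was a static_for over NumATensor
        b_copy.MoveSrcSliceWindow(/*b_block_copy_step=*/8); // was a static_for over NumBTensor
    }
}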
@@ -167,17 +167,31 @@ struct ThreadwiseTensorSliceTransfer_v7r2
                     is_src_valid);
             });

-        if constexpr(!has_vec_len<decltype(element_op_)>::value)
+        if constexpr(has_vec_len<decltype(element_op_)>::value)
         {
+            constexpr auto elem_op_vec_len = decltype(element_op_)::vec_len;
+
+            static_assert(is_same<remove_cvref_t<decltype(elem_op_vec_len)>, index_t>::value,
+                          "vec_len in element_op_ type is not index_t");
+            static_assert(elem_op_vec_len == 2 || elem_op_vec_len == 4 || elem_op_vec_len == 8,
+                          "vec_len in element_op_ must be 2, 4, or 8");
+            static_assert(SrcScalarPerVector % elem_op_vec_len == 0,
+                          "SrcScalarPerVector must be divisible by vec_len in element_op_!");
+
             // apply pointwise function
-            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+            static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) {
                 // get reference to src data
                 const auto src_data_refs = generate_tie(
                     // return type should be lvalue
                     [&](auto iSrc) -> const auto& {
                         using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
-                        return src_vectors[iSrc].template AsType<SrcData>()[i];
+
+                        using elem_op_vec_t =
+                            typename vector_type<SrcData, elem_op_vec_len>::type;
+
+                        return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
                     },
                     Number<nSrc>{});
@@ -187,7 +201,10 @@ struct ThreadwiseTensorSliceTransfer_v7r2
                     [&](auto iDst) -> auto& {
                         using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
-                        return dst_vectors(iDst).template AsType<DstData>()(i);
+
+                        using elem_op_vec_t =
+                            typename vector_type<DstData, elem_op_vec_len>::type;
+
+                        return dst_vectors(iDst).template AsType<elem_op_vec_t>()(i);
                     },
                     Number<nDst>{});
@@ -204,29 +221,15 @@ struct ThreadwiseTensorSliceTransfer_v7r2
         }
         else
         {
-            constexpr auto elem_op_vec_len = decltype(element_op_)::vec_len;
-
-            static_assert(is_same<remove_cvref_t<decltype(elem_op_vec_len)>, index_t>::value,
-                          "vec_len in element_op_ type is not index_t");
-            static_assert(elem_op_vec_len == 2 || elem_op_vec_len == 4 || elem_op_vec_len == 8,
-                          "vec_len in element_op_ must be 2, 4, 8");
-            static_assert(SrcScalarPerVector % elem_op_vec_len == 0,
-                          "vec_len in element_op_ cannot be divided by SrcScalarPerVector!");
-
             // apply pointwise function
-            static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) {
+            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                 // get reference to src data
                 const auto src_data_refs = generate_tie(
                     // return type should be lvalue
                     [&](auto iSrc) -> const auto& {
                         using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
-
-                        using elem_op_vec_t =
-                            typename vector_type<SrcData, elem_op_vec_len>::type;
-
-                        return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
+                        return src_vectors[iSrc].template AsType<SrcData>()[i];
                     },
                     Number<nSrc>{});
@@ -236,10 +239,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
                     [&](auto iDst) -> auto& {
                         using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
-
-                        using elem_op_vec_t =
-                            typename vector_type<DstData, elem_op_vec_len>::type;
-
-                        return dst_vectors(iDst).template AsType<elem_op_vec_t>()(i);
+                        return dst_vectors(iDst).template AsType<DstData>()(i);
                     },
                     Number<nDst>{});
......
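The branch swap above keys off has_vec_len, which detects whether the element-wise operator publishes a static vec_len member; when it does, the transfer applies the op to vector_type chunks of that width (hence the SrcScalarPerVector / elem_op_vec_len trip count) instead of element by element. A standalone C++17 sketch of such a detection trait (this has_vec_len definition and the Scale2 op are illustrative assumptions, not CK's exact implementation):

#include <cstdint>
#include <type_traits>

using index_t = int32_t;

// detection idiom: true when T publishes a static member vec_len
template <typename T, typename = void>
struct has_vec_len : std::false_type
{
};

template <typename T>
struct has_vec_len<T, std::void_t<decltype(T::vec_len)>> : std::true_type
{
};

// scalar element op: no vec_len, taken through the element-by-element branch
struct PassThrough
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        y = x;
    }
};

// vectorized element op: consumes/produces 2-wide vector_type chunks
struct Scale2
{
    static constexpr index_t vec_len = 2;

    template <typename Y2, typename X2>
    void operator()(Y2& y, const X2& x) const
    {
        y = x; // operates on whole 2-element vectors
    }
};

static_assert(!has_vec_len<PassThrough>::value, "scalar path");
static_assert(has_vec_len<Scale2>::value, "vectorized path");
// mirrors the divisibility check added in the commit (SrcScalarPerVector = 8 here)
static_assert(8 % Scale2::vec_len == 0, "SrcScalarPerVector must be divisible by vec_len");

int main() {}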