Commit 57271814 authored by Chao Liu
Browse files

refactor

parent f9b92b1e
...@@ -14,60 +14,62 @@ namespace ck { ...@@ -14,60 +14,62 @@ namespace ck {
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <typename ThreadGroup, template <typename ThreadGroup,
typename ElementwiseOperation, typename ElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
typename SliceLengths, typename SliceLengths,
typename ThreadClusterLengths, typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder, typename ThreadClusterArrangeOrder,
typename Src0Data, typename SrcDatas,
typename Src1Data, typename DstDatas,
typename Src2Data, typename SrcDescs,
typename DstData, typename DstDescs,
typename Src0Desc,
typename Src1Desc,
typename Src2Desc,
typename DstDesc,
typename DimAccessOrder, typename DimAccessOrder,
index_t VectorDim, index_t VectorDim,
index_t ScalarPerVector, index_t ScalarPerVector,
bool ThreadTransferSrc0ResetCoordinateAfterRun, typename ThreadTransferSrcResetCoordinateAfterRunFlags,
bool ThreadTransferSrc1ResetCoordinateAfterRun, typename ThreadTransferDstResetCoordinateAfterRunFlags,
bool ThreadTransferSrc2ResetCoordinateAfterRun, InMemoryDataOperationEnum... DstInMemOps>
bool ThreadTransferDstResetCoordinateAfterRun>
struct ThreadGroupTensorSliceTransfer_v7 struct ThreadGroupTensorSliceTransfer_v7
{ {
static constexpr auto I0 = Number<0>{}; static constexpr index_t nDim =
static constexpr auto I1 = Number<1>{}; remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
static constexpr auto I2 = Number<2>{};
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension(); static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
__device__ constexpr ThreadGroupTensorSliceTransfer_v7(const Src0Desc& src0_desc, static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
const Index& src0_block_slice_origin,
const Src1Desc& src1_desc,
const Index& src1_block_slice_origin,
const Src2Desc& src2_desc,
const Index& src2_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
: threadwise_transfer_(tie(src0_desc, src1_desc, src2_desc),
make_tuple(make_zero_multi_index<nDim>(),
make_zero_multi_index<nDim>(),
make_zero_multi_index<nDim>()),
tie(dst_desc),
make_tuple(make_zero_multi_index<nDim>()),
element_op)
__device__ constexpr ThreadGroupTensorSliceTransfer_v7(
const SrcDescs& src_descs,
const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
const DstDescs& dst_descs,
const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src_descs,
StaticallyIndexedArray<Index, nSrc>{},
dst_descs,
StaticallyIndexedArray<Index, nDst>{},
element_op)
{ {
static_assert(nDim == remove_cvref_t<Src0Desc>::GetNumOfDimension() && static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
nDim == remove_cvref_t<Src1Desc>::GetNumOfDimension() && nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
nDim == remove_cvref_t<Src2Desc>::GetNumOfDimension() && nDst == DstDatas::Size() && nDst == DstDescs::Size() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() && nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(),
nDim == ThreadClusterLengths::Size() && "wrong!");
static_for<0, nSrc, 1>{}([&](auto i) {
static_assert(
nDim == remove_cvref_t<tuple_element_t<i.value, SrcDescs>>::GetNumOfDimension(),
"wrong!");
});
static_for<0, nDst, 1>{}([&](auto i) {
static_assert(
nDim == remove_cvref_t<tuple_element_t<i.value, DstDescs>>::GetNumOfDimension(),
"wrong!");
});
static_assert(nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() && nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(), nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent"); "wrong! nDim not consistent");
...@@ -87,73 +89,51 @@ struct ThreadGroupTensorSliceTransfer_v7 ...@@ -87,73 +89,51 @@ struct ThreadGroupTensorSliceTransfer_v7
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrcSliceOrigin( const auto src_thread_slice_origins = generate_tuple(
tie(src0_desc, src1_desc, src2_desc), [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
make_tuple(src0_block_slice_origin + thread_data_idx_begin, Number<nSrc>{});
src1_block_slice_origin + thread_data_idx_begin,
src2_block_slice_origin + thread_data_idx_begin));
threadwise_transfer_.SetDstSliceOrigin( const auto dst_thread_slice_origins = generate_tuple(
tie(dst_desc), make_tuple(dst_block_slice_origin + thread_data_idx_begin)); [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
} Number<nDst>{});
}
// Thread-group-level Run for the fixed three-source / one-destination form.
// Packs the three source descriptors, the three source buffers, the single
// destination descriptor and the single destination buffer with tie(), and
// forwards them to threadwise_transfer_.Run().
// Source buffers are taken by const reference; dst_buf is taken by non-const reference.
template <typename Src0Buffer, typename Src1Buffer, typename Src2Buffer, typename DstBuffer>
__device__ void Run(const Src0Desc& src0_desc,
const Src0Buffer& src0_buf,
const Src1Desc& src1_desc,
const Src1Buffer& src1_buf,
const Src2Desc& src2_desc,
const Src2Buffer& src2_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
// Only forward when the group size equals the thread-cluster element count,
// or when this thread's id is below that count; any other thread does nothing.
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(tie(src0_desc, src1_desc, src2_desc),
tie(src0_buf, src1_buf, src2_buf),
tie(dst_desc),
tie(dst_buf));
}
}
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step) threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
{ threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(
tie(src0_desc, Src1Desc{}, Src2Desc{}), step, I0);
} }
} }
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step) template <typename SrcBuffers, typename DstBuffers>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs)
{ {
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveSrcSliceWindow( threadwise_transfer_.Run(src_descs, src_bufs, dst_descs, dst_bufs);
tie(Src0Desc{}, src1_desc, Src2Desc{}), step, I1);
} }
} }
__device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step) template <index_t ISrc>
__device__ void
MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
{ {
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveSrcSliceWindow( threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
tie(Src0Desc{}, Src1Desc{}, src2_desc), step, I2);
} }
} }
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) template <index_t IDst>
__device__ void
MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
{ {
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{ {
threadwise_transfer_.MoveDstSliceWindow(tie(dst_desc), step, I0); threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
} }
} }
...@@ -161,23 +141,19 @@ struct ThreadGroupTensorSliceTransfer_v7 ...@@ -161,23 +141,19 @@ struct ThreadGroupTensorSliceTransfer_v7
static constexpr auto thread_cluster_desc_ = static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer = ThreadwiseTensorSliceTransfer_v7< using ThreadwiseTransfer =
Tuple<remove_cvref_t<Src0Data>, remove_cvref_t<Src1Data>, remove_cvref_t<Src2Data>>, ThreadwiseTensorSliceTransfer_v7<SrcDatas,
Tuple<remove_cvref_t<DstData>>, DstDatas,
Tuple<remove_reference_t<Src0Desc>&, SrcDescs,
remove_reference_t<Src1Desc>&, DstDescs,
remove_reference_t<Src2Desc>&>, ElementwiseOperation,
Tuple<remove_reference_t<DstDesc>&>, decltype(thread_slice_lengths),
ElementwiseOperation, DimAccessOrder,
decltype(thread_slice_lengths), VectorDim,
DimAccessOrder, ScalarPerVector,
VectorDim, ThreadTransferSrcResetCoordinateAfterRunFlags,
ScalarPerVector, ThreadTransferDstResetCoordinateAfterRunFlags,
Sequence<ThreadTransferSrc0ResetCoordinateAfterRun, DstInMemOps...>;
ThreadTransferSrc1ResetCoordinateAfterRun,
ThreadTransferSrc2ResetCoordinateAfterRun>,
Sequence<ThreadTransferDstResetCoordinateAfterRun>,
DstInMemOp>;
ThreadwiseTransfer threadwise_transfer_; ThreadwiseTransfer threadwise_transfer_;
}; };
......
...@@ -542,77 +542,39 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -542,77 +542,39 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
ck::tensor_operation::element_wise::PassThrough{}}; ck::tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global // shuffle: blockwise copy C from LDS to global
#if 0 // FIXME: arbitrary # of D tensors
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< const auto c_ds_descs = tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
ThisThreadBlock, // ThreadGroup ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
CDEElementwiseOperation, // ElementwiseOperation, ds_grid_desc_mblock_mperblock_nblock_nperblock[I1]);
EGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename Src0Data,
remove_cvref_t<decltype(DsDataType{}[I0])>, // typename Src1Data,
remove_cvref_t<decltype(DsDataType{}[I1])>, // typename Src2Data,
FloatE, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I0]),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I1]),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrc0ResetCoordinateAfterRun,
false, // bool ThreadTransferSrc1ResetCoordinateAfterRun,
false, // bool ThreadTransferSrc2ResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
e_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
cde_element_op};
#else
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock, // ThreadGroup ThisThreadBlock, // ThreadGroup
CDEElementwiseOperation, // ElementwiseOperation, CDEElementwiseOperation, // ElementwiseOperation,
EGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1, Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1, 1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename Src0Data, Tuple<FloatCShuffle,
remove_cvref_t<tuple_element_t<0, DsDataType>>, // typename Src1Data, remove_cvref_t<tuple_element_t<0, DsDataType>>,
remove_cvref_t<tuple_element_t<1, DsDataType>>, // typename Src2Data, remove_cvref_t<tuple_element_t<1, DsDataType>>>,
FloatE, // typename DstData, Tuple<FloatE>, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), decltype(c_ds_descs),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I0]), decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I1]),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder, Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim, 3, // index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrc0ResetCoordinateAfterRun, Sequence<true, false, false>, // bool ThreadTransferSrcResetCoordinateAfterRunFlags
false, // bool ThreadTransferSrc1ResetCoordinateAfterRun, Sequence<false>, // bool ThreadTransferDstResetCoordinateAfterRunFlags
false, // bool ThreadTransferSrc2ResetCoordinateAfterRun, EGlobalMemoryDataOperation> // DstInMemOp,
false> // bool ThreadTransferDstResetCoordinateAfterRun> {c_ds_descs,
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, make_tuple(make_multi_index(0, 0, 0, 0),
make_multi_index(0, 0, 0, 0), make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0], make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1], make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
e_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
cde_element_op}; cde_element_op};
#endif
// space filling curve for threadwise C in VGPR before shuffle // space filling curve for threadwise C in VGPR before shuffle
constexpr auto sfc_c_vgpr = constexpr auto sfc_c_vgpr =
...@@ -655,42 +617,25 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -655,42 +617,25 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
block_sync_lds(); block_sync_lds();
// each block copy its data from LDS to global // each block copy its data from LDS to global
#if 1
cde_block_copy_lds_and_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
ds_grid_buf[I0],
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
ds_grid_buf[I1],
e_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_buf);
#else
cde_block_copy_lds_and_global.Run( cde_block_copy_lds_and_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, c_ds_descs,
c_shuffle_block_buf, tie(c_shuffle_block_buf, ds_grid_buf[I0], ds_grid_buf[I1]),
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0], tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
ds_grid_buf[I0], tie(e_grid_buf));
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
ds_grid_buf[I1],
e_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_buf);
#endif
if constexpr(access_id < num_access - 1) if constexpr(access_id < num_access - 1)
{ {
constexpr auto c_global_step = sfc_cde_block.GetForwardStep(access_id); constexpr auto e_global_step = sfc_cde_block.GetForwardStep(access_id);
// move on Ds // move on Ds
cde_block_copy_lds_and_global.MoveSrc1SliceWindow( static_for<0, DsDataType::Size(), 1>{}([&](auto i) {
ds_grid_desc_mblock_mperblock_nblock_nperblock[I0], c_global_step); cde_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_descs, i + I1, e_global_step);
cde_block_copy_lds_and_global.MoveSrc2SliceWindow( });
ds_grid_desc_mblock_mperblock_nblock_nperblock[I1], c_global_step);
// move on E // move on E
cde_block_copy_lds_and_global.MoveDstSliceWindow( cde_block_copy_lds_and_global.MoveDstSliceWindow(
e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); tie(e_grid_desc_mblock_mperblock_nblock_nperblock), I0, e_global_step);
} }
}); });
} }
......
...@@ -80,8 +80,8 @@ struct ThreadwiseTensorSliceTransfer_v7 ...@@ -80,8 +80,8 @@ struct ThreadwiseTensorSliceTransfer_v7
} }
template <typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(), bool> = false> template <typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(), bool> = false>
__device__ void SetSrcSliceOrigin(const SrcDescs& src_descs, __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs,
const Indices& src_slice_origin_idxs) const Indices& src_slice_origin_idxs)
{ {
static_for<0, nSrc, 1>{}([&](auto i) { static_for<0, nSrc, 1>{}([&](auto i) {
src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]); src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]);
...@@ -89,8 +89,8 @@ struct ThreadwiseTensorSliceTransfer_v7 ...@@ -89,8 +89,8 @@ struct ThreadwiseTensorSliceTransfer_v7
} }
template <typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(), bool> = false> template <typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(), bool> = false>
__device__ void SetDstSliceOrigin(const DstDescs& dst_descs, __device__ void SetDstSliceOrigins(const DstDescs& dst_descs,
const Indices& dst_slice_origin_idxs) const Indices& dst_slice_origin_idxs)
{ {
static_for<0, nDst, 1>{}([&](auto i) { static_for<0, nDst, 1>{}([&](auto i) {
dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]);
...@@ -234,8 +234,8 @@ struct ThreadwiseTensorSliceTransfer_v7 ...@@ -234,8 +234,8 @@ struct ThreadwiseTensorSliceTransfer_v7
// src_slice_origin_step_idx need to be known at compile-time, for performance reason // src_slice_origin_step_idx need to be known at compile-time, for performance reason
template <index_t ISrc> template <index_t ISrc>
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
const Index& src_slice_origin_step_idx, Number<ISrc> iSrc,
Number<ISrc> iSrc) const Index& src_slice_origin_step_idx)
{ {
// if src coord was not reset by RunRead(), then need to adjust the step here // if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx = SrcResetCoordinateAfterRunFlags::At(iSrc) const auto adjusted_step_idx = SrcResetCoordinateAfterRunFlags::At(iSrc)
...@@ -251,8 +251,8 @@ struct ThreadwiseTensorSliceTransfer_v7 ...@@ -251,8 +251,8 @@ struct ThreadwiseTensorSliceTransfer_v7
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
template <index_t IDst> template <index_t IDst>
__device__ void MoveDstSliceWindow(const DstDescs& dst_descs, __device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
const Index& dst_slice_origin_step_idx, Number<IDst> iDst,
Number<IDst> iDst) const Index& dst_slice_origin_step_idx)
{ {
// if dst coord was not reset by Run(), then need to adjust the step here // if dst coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx = DstResetCoordinateAfterRunFlags::At(iDst) const auto adjusted_step_idx = DstResetCoordinateAfterRunFlags::At(iDst)
...@@ -265,22 +265,6 @@ struct ThreadwiseTensorSliceTransfer_v7 ...@@ -265,22 +265,6 @@ struct ThreadwiseTensorSliceTransfer_v7
move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
} }
// Applies the same step to every source slice window: for each i in [0, nSrc)
// it calls MoveSrcSliceWindow with src_descs, the step, and the index i.
// src_slice_origin_step_idx needs to be known at compile time, for performance reasons.
// NOTE(review): the call below passes (descs, step, index); the MoveSrcSliceWindow
// signature also appears on this page with (descs, index, step) - confirm which
// order applies wherever this helper is used.
__device__ void MoveAllSrcSliceWindow(const SrcDescs& src_descs,
const Index& src_slice_origin_step_idx)
{
static_for<0, nSrc, 1>{}(
[&](auto i) { MoveSrcSliceWindow(src_descs, src_slice_origin_step_idx, i); });
}
// Applies the same step to every destination slice window: for each i in [0, nDst)
// it calls MoveDstSliceWindow with dst_descs, the step, and the index i.
// dst_slice_origin_step_idx needs to be known at compile time, for performance reasons.
// NOTE(review): the call below passes (descs, step, index); the MoveDstSliceWindow
// signature also appears on this page with (descs, index, step) - confirm which
// order applies wherever this helper is used.
__device__ void MoveAllDstSliceWindow(const DstDescs& dst_descs,
const Index& dst_slice_origin_step_idx)
{
static_for<0, nDst, 1>{}(
[&](auto i) { MoveDstSliceWindow(dst_descs, dst_slice_origin_step_idx, i); });
}
private: private:
SrcCoords src_coords_; SrcCoords src_coords_;
DstCoords dst_coords_; DstCoords dst_coords_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment