Unverified commit 0a7174ad authored by Chao Liu, committed by GitHub

Merge with (not the latest) upstream CK (#32)

* fix build for old ck examples

* fix build for old ck
parent 496be40e
@@ -24,6 +24,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/literals.hpp"
+#include "ck/library/utility/fill.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 template <ck::index_t... Is>
...
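The newly included `fill.hpp` provides the host-side tensor fill utilities the examples use to initialize their inputs. A minimal sketch of the typical call pattern, assuming the `ck::utils::FillUniformDistribution` functor and the public `mData` member of `Tensor` behave as in upstream CK (the tensor names and value range are illustrative, not taken from this commit):

```cpp
// Sketch only: initialize example inputs with the fill helpers from fill.hpp.
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/fill.hpp"

void init_inputs(Tensor<float>& a_g_m_k, Tensor<float>& b0_g_k_n)
{
    // Draw values uniformly from [-1, 1] for both operands.
    ck::utils::FillUniformDistribution<float>{-1.f, 1.f}(a_g_m_k.mData.begin(),
                                                         a_g_m_k.mData.end());
    ck::utils::FillUniformDistribution<float>{-1.f, 1.f}(b0_g_k_n.mData.begin(),
                                                         b0_g_k_n.mData.end());
}
```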
@@ -25,6 +25,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/literals.hpp"
+#include "ck/library/utility/fill.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 template <ck::index_t... Is>
...
@@ -24,6 +24,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/literals.hpp"
+#include "ck/library/utility/fill.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 template <ck::index_t... Is>
...
@@ -28,6 +28,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/literals.hpp"
+#include "ck/library/utility/fill.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 template <ck::index_t... Is>
...
@@ -24,6 +24,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/literals.hpp"
+#include "ck/library/utility/fill.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 template <ck::index_t... Is>
...
@@ -54,7 +54,7 @@ struct GemmGemm
     // block gemm1
     using BlockGemm1 = ck::tile_program::block::BlockGemmARegBSmemCRegV1<
-        ck::tile_program::block::BlockGemmARegBSmemCRegV1Problem<
+        ck::tile_program::block::BlockGemmARegBSmemCRegProblem<
             C0DataType,
             B1DataType,
             Acc1DataType,
...
@@ -544,9 +544,9 @@ struct Merge_v2_magic_division : public BaseTransform<LowLengths::Size(), 1>
     using UpLengths =
         decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
-    using LowLengthsMagicDivisor = decltype(
-        generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_divisor<LowLengths>{},
-                       Number<NDimLow>{}));
+    using LowLengthsMagicDivisor = decltype(generate_tuple(
+        lambda_merge_generate_MagicDivision_calculate_magic_divisor<LowLengths>{},
+        Number<NDimLow>{}));
     LowLengths low_lengths_;
     LowLengthsMagicDivisor low_lengths_magic_divisor_;
@@ -986,6 +986,73 @@ struct Freeze : public BaseTransform<1, 0>
     }
 };
 
+// Insert a dangling upper dimension without lower dimension
+template <typename UpperLength>
+struct Insert : public BaseTransform<0, 1>
+{
+    using UpLengths = decltype(make_tuple(UpperLength{}));
+
+    UpLengths up_lengths_;
+
+    __host__ __device__ constexpr Insert() = default;
+
+    __host__ __device__ constexpr Insert(const UpperLength& up_length)
+        : up_lengths_{make_tuple(up_length)}
+    {
+    }
+
+    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 0; }
+
+    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
+
+    __host__ __device__ constexpr auto GetUpperLengths() const { return up_lengths_; }
+
+    template <typename LowIdx, typename UpIdx>
+    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx&, const UpIdx&) const
+    {
+        static_assert(LowIdx::Size() == 0 && UpIdx::Size() == 1,
+                      "wrong! inconsistent # of dimension");
+    }
+
+    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
+    __host__ __device__ static void
+    UpdateLowerIndex(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&)
+    {
+        static_assert(LowIdxDiff::Size() == 0 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 0 &&
+                          UpIdx::Size() == 1,
+                      "wrong! inconsistent # of dimension");
+    }
+
+    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
+
+    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
+    {
+        return true;
+    }
+
+    template <typename UpIdx>
+    __host__ __device__ static constexpr bool
+    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
+    {
+        return true;
+    }
+
+    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
+    {
+        return is_known_at_compile_time<UpperLength>::value;
+    }
+
+    __host__ __device__ void Print() const
+    {
+        printf("Insert{");
+        //
+        print(up_lengths_);
+        printf("}");
+    }
+};
+
 // Replicate the original tensor and create a higher dimensional tensor
 template <typename UpLengths>
 struct Replicate : public BaseTransform<0, UpLengths::Size()>
...
@@ -85,6 +85,12 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i
     return Freeze<LowerIndex>{low_idx};
 }
 
+template <typename UpperIndex>
+__host__ __device__ constexpr auto make_insert_transform(const UpperIndex& up_idx)
+{
+    return Insert<UpperIndex>{up_idx};
+}
+
 template <typename UpLengths>
 __host__ __device__ constexpr auto make_replicate_transform(const UpLengths& up_lengths)
 {
...
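`make_insert_transform` wraps the new `Insert` transform the same way the other helpers wrap theirs. A hedged sketch of how it could be combined with `transform_tensor_descriptor` to append a dangling upper dimension, assuming the usual `transform_tensor_descriptor(old_desc, transforms, lower_ids, upper_ids)` interface; the descriptor names and the length 4 are illustrative, not from this commit:

```cpp
// Sketch only: append a length-4 upper dimension that maps to no lower dimension.
const auto desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));

const auto desc_m_n_g = transform_tensor_descriptor(
    desc_m_n,
    make_tuple(make_pass_through_transform(M),
               make_pass_through_transform(N),
               make_insert_transform(Number<4>{})), // dangling upper dimension
    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<>{}), // Insert consumes no lower dims
    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
```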
@@ -8,6 +8,8 @@
 #include "ck/utility/sequence_helper.hpp"
 #include "ck/utility/multi_index.hpp"
 #include "ck/utility/tuple_helper.hpp"
+#include "ck/tensor_description/multi_index_transform.hpp"
+#include "ck/tensor_description/multi_index_transform_helper.hpp"
 #include "ck/tensor_description/tensor_adaptor.hpp"
 namespace ck {
...
@@ -154,9 +154,9 @@ struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_
         : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
               get_thread_local_1d_id())},
           a_thread_copy_{
-              make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)},
+              make_multi_index(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)},
           b_thread_copy_{
-              make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)}
+              make_multi_index(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)}
     {
         static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
                           BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
...
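The repeated `make_tuple` to `make_multi_index` changes in this commit all address the same issue: the coordinate and thread-copy interfaces expect a multi-index whose elements are `index_t`, whereas `make_tuple` deduces element types from its arguments. A short sketch of the distinction (the exact alias names are assumptions based on `ck/utility/multi_index.hpp`):

```cpp
// Sketch only: literal arguments deduce to int with make_tuple,
// while make_multi_index forces every element to index_t.
const auto as_tuple = make_tuple(0, 1, 2);       // Tuple<int, int, int>
const auto as_index = make_multi_index(0, 1, 2); // MultiIndex<3>, elements of type index_t
```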
@@ -74,7 +74,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
         const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
 
-        return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
+        return make_multi_index(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
     }
 
     __device__ static auto CalculateBThreadOriginDataIndex()
@@ -85,7 +85,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
         const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
 
-        return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
+        return make_multi_index(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
     }
 
     template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
@@ -110,9 +110,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
             make_tuple(Sequence<0, 1, 2>{}));
 
         const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
-            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
+            make_multi_index(m0, waveId_m, blk_idx[I0]))[I0];
         const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
-            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
+            make_multi_index(n0, waveId_n, blk_idx[I1]))[I0];
 
         return make_tuple(c_thread_m, c_thread_n);
     }
...
@@ -881,8 +881,8 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
         false, // TransposeC
         Gemm1KPack, // AMmaKStride
         Gemm1KPack * XdlopsGemm<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl, Gemm1KPack, false>{}
             .K0PerXdlops>{ // BMmaKStride
-        make_tuple(0, 0, 0, 0)}; // A_origin
+        make_multi_index(0, 0, 0, 0)}; // A_origin
 
     auto c1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
...
@@ -972,7 +972,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
 
                 reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                     reduce_grid_desc_mblock_mperblock,
-                    make_tuple(c_global_step[I0], c_global_step[I1]));
+                    make_multi_index(c_global_step[I0], c_global_step[I1]));
             }
         });
     }
...
@@ -933,7 +933,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
 
                 reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                     rs_grid_desc_mblock_mperblock[Ir],
-                    make_tuple(de_global_step[I0], de_global_step[I1]));
+                    make_multi_index(de_global_step[I0], de_global_step[I1]));
             }
         });
     }); // copy c, d, e + reduction
...
@@ -858,7 +858,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
 
                 reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                     reduce_grid_desc_mblock_mperblock,
-                    make_tuple(c_global_step[I0], c_global_step[I1]));
+                    make_multi_index(c_global_step[I0], c_global_step[I1]));
             }
         });
     }
...
@@ -73,16 +73,11 @@ struct Merge_v4_no_carry
         idx_low(Number<NDimLow - 1>{}) = tmp;
     }
 
-    template <typename LowIdxDiff,
-              typename UpIdxDiff,
-              typename LowIdx,
-              typename UpIdx,
-              index_t Hack>
+    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
     __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                               const UpIdxDiff& idx_up_diff,
                                               LowIdx& idx_low,
-                                              const UpIdx& idx_up_new,
-                                              Number<Hack>) const
+                                              const UpIdx& idx_up_new) const
     {
         static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                           LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
...
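With the `Number<Hack>` template parameter dropped, `Merge_v4_no_carry::UpdateLowerIndex` now takes the same four index arguments as the other transforms touched in this commit. A hedged sketch of the corresponding call-site change, with a placeholder wrapper since the actual callers are not shown in this diff:

```cpp
// Sketch only, placeholder names: stepping the merge transform without the hack tag.
template <typename Merge, typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
__host__ __device__ void step_merge(const Merge& merge,
                                    LowIdxDiff& idx_diff_low,
                                    const UpIdxDiff& idx_up_diff,
                                    LowIdx& idx_low,
                                    const UpIdx& idx_up_new)
{
    // before: merge.UpdateLowerIndex(idx_diff_low, idx_up_diff, idx_low, idx_up_new, Number<0>{});
    merge.UpdateLowerIndex(idx_diff_low, idx_up_diff, idx_low, idx_up_new);
}
```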
@@ -1040,7 +1040,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
             c_block_copy_lds_to_global.SetSrcSliceOrigin(
                 c_block_desc_mblock_mpershuffle_nblock_npershuffle,
-                make_tuple(0, 0, 0, 0));
+                make_multi_index(0, 0, 0, 0));
 
             // LDS to global
             if(is_dp_block)
@@ -1059,11 +1059,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                 // constexpr offset
                 c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
                     c_block_desc_mblock_mpershuffle_nblock_npershuffle,
-                    make_tuple(0, 0, 0, 0));
+                    make_multi_index(0, 0, 0, 0));
 
                 c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
                     c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
-                    make_tuple(mxdlperwave.value, 0, nxdlperwave.value, 0));
+                    make_multi_index(mxdlperwave.value, 0, nxdlperwave.value, 0));
 
                 c_block_copy_lds_to_partial_acc
                     .template Run<decltype(c_block_buf),
...
@@ -7,6 +7,7 @@
 #include "ck/tensor_description/tensor_space_filling_curve.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_description/tensor_coordinate.hpp"
 
 namespace ck {
@@ -200,10 +201,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
             DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
                                        : dst_slice_origin_step_idx + GetDstCoordinateResetStep();
 
-        // is it OK to construct a new step every time?
-        const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
-
-        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
+        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step_idx);
     }
 
     private:
@@ -439,9 +437,8 @@ struct ThreadwiseTensorSliceTransfer_v3
         dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
     }
 
-    template <typename SrcBuffer, typename SrcStepHacks>
-    __device__ void
-    RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
+    template <typename SrcBuffer>
+    __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
     {
         static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
                           SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
@@ -452,7 +449,6 @@ struct ThreadwiseTensorSliceTransfer_v3
                       "wrong! SrcBuffer and SrcData data type are inconsistent");
 
         constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
 
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
@@ -472,28 +468,26 @@ struct ThreadwiseTensorSliceTransfer_v3
         // make forward steps
         const auto src_forward_steps = generate_tuple(
             [&](auto i) {
-                Index forward_step_idx;
+                Index forward_step;
 
                 static_for<0, nDim, 1>{}([&](auto j) {
-                    forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
+                    forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
                 });
 
-                return make_tensor_coordinate_step(
-                    src_desc, forward_step_idx, src_step_hacks[I0][i]);
+                return forward_step;
             },
             Number<nDim>{});
 
         // make backward steps
         const auto src_backward_steps = generate_tuple(
             [&](auto i) {
-                Index backward_step_idx;
+                Index backward_step;
 
                 static_for<0, nDim, 1>{}([&](auto j) {
-                    backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
+                    backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
                 });
 
-                return make_tensor_coordinate_step(
-                    src_desc, backward_step_idx, src_step_hacks[I1][i]);
+                return backward_step;
             },
             Number<nDim>{});
@@ -589,16 +583,12 @@ struct ThreadwiseTensorSliceTransfer_v3
         // move src coordinate back to slice origin (or not)
         if constexpr(SrcResetCoordinateAfterRun)
         {
-            const auto src_reset_step =
-                make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
-
-            move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
+            move_tensor_coordinate(src_desc, src_coord_, GetSrcCoordinateResetStep());
         }
     }
 
-    template <typename DstBuffer, typename DstStepHacks>
-    __device__ void
-    RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
+    template <typename DstBuffer>
+    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
     {
         static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
                           DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
@@ -609,7 +599,6 @@ struct ThreadwiseTensorSliceTransfer_v3
                       "wrong! SrcBuffer or DstBuffer data type is wrong");
 
         constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
 
         // src scalar per access on each dim
         // TODO: don't use this
@@ -629,28 +618,26 @@ struct ThreadwiseTensorSliceTransfer_v3
         // make forward steps
         const auto dst_forward_steps = generate_tuple(
             [&](auto i) {
-                Index forward_step_idx;
+                Index forward_step;
 
                 static_for<0, nDim, 1>{}([&](auto j) {
-                    forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
+                    forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
                 });
 
-                return make_tensor_coordinate_step(
-                    dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
+                return forward_step;
             },
             Number<nDim>{});
 
         // make backward steps
         const auto dst_backward_steps = generate_tuple(
             [&](auto i) {
-                Index backward_step_idx;
+                Index backward_step;
 
                 static_for<0, nDim, 1>{}([&](auto j) {
-                    backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
+                    backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
                 });
 
-                return make_tensor_coordinate_step(
-                    dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
+                return backward_step;
             },
             Number<nDim>{});
@@ -749,41 +736,10 @@ struct ThreadwiseTensorSliceTransfer_v3
         // move dst coordinate back to slice origin (or not)
         if constexpr(DstResetCoordinateAfterRun)
        {
-            const auto dst_reset_step =
-                make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
-
-            move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
+            move_tensor_coordinate(dst_desc, dst_coord_, GetDstCoordinateResetStep());
        }
     }
 
-    template <typename SrcBuffer>
-    __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
-    {
-        constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
-
-        constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
-
-        constexpr auto src_step_hacks =
-            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
-                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
-
-        RunRead(src_desc, src_buf, src_step_hacks);
-    }
-
-    template <typename DstBuffer>
-    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
-    {
-        constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
-
-        constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
-
-        constexpr auto dst_step_hacks =
-            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
-                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
-
-        RunWrite(dst_desc, dst_buf, dst_step_hacks);
-    }
-
     __device__ static constexpr auto GetSrcCoordinateResetStep()
     {
         constexpr auto I0 = Number<0>{};
@@ -909,46 +865,22 @@ struct ThreadwiseTensorSliceTransfer_v3
                                        const Index& src_slice_origin_step_idx)
     {
         // if src coord was not reset by RunRead(), then need to adjust the step here
-        const auto adjusted_step_idx =
-            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
-                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
-
-        // is it OK to construct a new step every time?
-        const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
+        const auto adjusted_step = SrcResetCoordinateAfterRun
+                                       ? src_slice_origin_step_idx
+                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
 
         move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
     }
 
-    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
-    template <typename SrcMoveSliceWindowStepHack>
-    __device__ void
-    MoveSrcSliceWindow(const SrcDesc& src_desc,
-                       const Index& src_slice_origin_step_idx,
-                       const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
-    {
-        // if src coord was not reset by RunRead(), then need to adjust the step here
-        const auto adjusted_step_idx =
-            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
-                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
-
-        // is it OK to construct a new step every time?
-        const auto adjusted_step = make_tensor_coordinate_step(
-            src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
-
-        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
-    }
-
     // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
     __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                        const Index& dst_slice_origin_step_idx)
     {
         // if dst coord was not reset by RunWrite(), then need to adjust the step here
-        const auto adjusted_step_idx =
-            DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
-                                       : dst_slice_origin_step_idx + GetDstCoordinateResetStep();
-
-        // is it OK to construct a new step every time?
-        const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
+        const auto adjusted_step = DstResetCoordinateAfterRun
+                                       ? dst_slice_origin_step_idx
+                                       : dst_slice_origin_step_idx + GetDstCoordinateResetStep();
 
         move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
     }
@@ -1143,13 +1075,9 @@ struct ThreadwiseTensorSliceTransfer_v4
     __device__ void MoveSrcSliceWindow(const SrcDesc&,
                                        const SrcSliceMoveStepIdx& src_slice_move_step_idx)
     {
-        constexpr auto src_desc = SrcDesc{};
-
-        const auto src_slice_move_step_iter =
-            make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));
-
-        move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
+        move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_idx);
     }
 
     __device__ void SetSrcCoord(const Index& src_ref_idx)
     {
         src_ref_coord_ = make_tensor_coordinate(SrcDesc{}, src_ref_idx);
...
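Across the `ThreadwiseTensorSliceTransfer_*` classes the pattern of this commit is the same: instead of materializing a `tensor_coordinate_step` (optionally carrying step hacks) and then moving, the coordinate is advanced directly with the step index. A hedged before/after sketch with placeholder names, since the surrounding descriptor and coordinate types are not spelled out here:

```cpp
// Sketch only, placeholder names: advancing a tensor coordinate by one step index.
template <typename Desc, typename Coord, typename Index>
__device__ void advance(const Desc& desc, Coord& coord, const Index& step_idx)
{
    // before this commit:
    //   const auto step = make_tensor_coordinate_step(desc, step_idx);
    //   move_tensor_coordinate(desc, coord, step);
    // after this commit, the step index is passed straight through:
    move_tensor_coordinate(desc, coord, step_idx);
}
```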
@@ -42,8 +42,6 @@ struct ThreadwiseTensorSliceTransfer_v4r1
     using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
 
-    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
-
     __device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx)
         : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
     {
@@ -122,12 +120,9 @@ struct ThreadwiseTensorSliceTransfer_v4r1
                 constexpr auto src_ref_to_data_disp_idx =
                     src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
 
-                constexpr auto src_ref_to_data_disp_coord_step =
-                    make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
-
                 auto src_data_coord = src_ref_coord_;
 
-                move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
+                move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_idx);
 
                 vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
@@ -162,10 +157,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
     {
         constexpr auto src_desc = SrcDesc{};
 
-        const auto src_slice_move_step_iter =
-            make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));
-
-        move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
+        move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_idx);
     }
 
     private:
...
@@ -44,9 +44,6 @@ struct ThreadwiseTensorSliceTransfer_v5r1
     using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
     using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
 
-    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
-    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
-
     __device__ constexpr ThreadwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc,
                                                             const Index& src_slice_origin,
                                                             const DstDesc& dst_desc,
@@ -75,9 +72,8 @@ struct ThreadwiseTensorSliceTransfer_v5r1
         dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
     }
 
-    template <typename SrcBuffer, typename SrcStepHacks>
-    __device__ void
-    RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
+    template <typename SrcBuffer>
+    __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
     {
         static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
                           SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
@@ -113,28 +109,26 @@ struct ThreadwiseTensorSliceTransfer_v5r1
         // make forward steps
         const auto src_forward_steps = generate_tuple(
             [&](auto i) {
-                Index forward_step_idx;
+                Index forward_step;
 
                 static_for<0, nDim, 1>{}([&](auto j) {
-                    forward_step_idx(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0;
+                    forward_step(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0;
                 });
 
-                return make_tensor_coordinate_step(
-                    src_desc, forward_step_idx, src_step_hacks[I0][i]);
+                return forward_step;
             },
             Number<nDim>{});
 
         // make backward steps
         const auto src_backward_steps = generate_tuple(
             [&](auto i) {
-                Index backward_step_idx;
+                Index backward_step;
 
                 static_for<0, nDim, 1>{}([&](auto j) {
-                    backward_step_idx(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0;
+                    backward_step(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0;
                 });
 
-                return make_tensor_coordinate_step(
-                    src_desc, backward_step_idx, src_step_hacks[I1][i]);
+                return backward_step;
             },
             Number<nDim>{});
@@ -236,16 +230,12 @@ struct ThreadwiseTensorSliceTransfer_v5r1
         // move src coordinate back to slice origin (or not)
         if constexpr(SrcResetCoordinateAfterRun)
         {
-            const auto src_reset_step =
-                make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
-
-            move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
+            move_tensor_coordinate(src_desc, src_coord_, GetSrcCoordinateResetStep());
         }
     }
 
-    template <typename DstBuffer, typename DstStepHacks>
-    __device__ void
-    RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
+    template <typename DstBuffer>
+    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
     {
         static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
                           DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
@@ -281,28 +271,26 @@ struct ThreadwiseTensorSliceTransfer_v5r1
         // make forward steps
         const auto dst_forward_steps = generate_tuple(
             [&](auto i) {
-                Index forward_step_idx;
+                Index forward_step;
 
                 static_for<0, nDim, 1>{}([&](auto j) {
-                    forward_step_idx(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0;
+                    forward_step(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0;
                 });
 
-                return make_tensor_coordinate_step(
-                    dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
+                return forward_step;
             },
             Number<nDim>{});
 
         // make backward steps
         const auto dst_backward_steps = generate_tuple(
             [&](auto i) {
-                Index backward_step_idx;
+                Index backward_step;
 
                 static_for<0, nDim, 1>{}([&](auto j) {
-                    backward_step_idx(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0;
+                    backward_step(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0;
                 });
 
-                return make_tensor_coordinate_step(
-                    dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
+                return backward_step;
             },
             Number<nDim>{});
@@ -406,41 +394,10 @@ struct ThreadwiseTensorSliceTransfer_v5r1
         // move dst coordinate back to slice origin (or not)
         if constexpr(DstResetCoordinateAfterRun)
         {
-            const auto dst_reset_step =
-                make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
-
-            move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
+            move_tensor_coordinate(dst_desc, dst_coord_, GetDstCoordinateResetStep());
         }
     }
 
-    template <typename SrcBuffer>
-    __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
-    {
-        constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
-
-        constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
-
-        constexpr auto src_step_hacks =
-            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
-                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
-
-        RunRead(src_desc, src_buf, src_step_hacks);
-    }
-
-    template <typename DstBuffer>
-    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
-    {
-        constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
-
-        constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
-
-        constexpr auto dst_step_hacks =
-            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
-                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
-
-        RunWrite(dst_desc, dst_buf, dst_step_hacks);
-    }
-
     __device__ static constexpr auto GetSrcCoordinateResetStep()
     {
         constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{};
@@ -556,46 +513,22 @@ struct ThreadwiseTensorSliceTransfer_v5r1
                                        const Index& src_slice_origin_step_idx)
     {
         // if src coord was not reset by RunRead(), then need to adjust the step here
-        const auto adjusted_step_idx =
-            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
-                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
-
-        // is it OK to construct a new step every time?
-        const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
+        const auto adjusted_step = SrcResetCoordinateAfterRun
+                                       ? src_slice_origin_step_idx
+                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
 
         move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
     }
 
-    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
-    template <typename SrcMoveSliceWindowStepHack>
-    __device__ void
-    MoveSrcSliceWindow(const SrcDesc& src_desc,
-                       const Index& src_slice_origin_step_idx,
-                       const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
-    {
-        // if src coord was not reset by RunRead(), then need to adjust the step here
-        const auto adjusted_step_idx =
-            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
-                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
-
-        // is it OK to construct a new step every time?
-        const auto adjusted_step = make_tensor_coordinate_step(
-            src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
-
-        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
-    }
-
     // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
     __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                        const Index& dst_slice_origin_step_idx)
     {
         // if dst coord was not reset by RunWrite(), then need to adjust the step here
-        const auto adjusted_step_idx =
-            DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
-                                       : dst_slice_origin_step_idx + GetDstCoordinateResetStep();
-
-        // is it OK to construct a new step every time?
-        const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
+        const auto adjusted_step = DstResetCoordinateAfterRun
+                                       ? dst_slice_origin_step_idx
+                                       : dst_slice_origin_step_idx + GetDstCoordinateResetStep();
 
         move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
     }
...