Commit c9b52bce authored by Chao Liu
Browse files

Threadwise copy v1r3: take the (compile-time) SrcSliceOriginIdx as an argument

parent dea22d0e
......@@ -417,7 +417,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1))
.Run(p_c_thread,
.Run(c_m0_m1_n0_n1_thread_desc,
make_tuple(I0, I0, I0, I0),
p_c_thread,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_m0_m1_n0_n1_global_tensor_iterator_hacks);
......
......@@ -69,17 +69,27 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}
template <typename DstIteratorHacks>
__device__ void Run(const SrcData* p_src,
template <typename SrcSliceOriginIdx, typename DstIteratorHacks>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcData* p_src,
const DstDesc& dst_desc,
DstData* p_dst,
const DstIteratorHacks& dst_iterator_hacks)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<SrcSliceOriginIdx>>>::value,
"wrong! SrcSliceOrigin need to known at compile-time");
// Comments: src_desc is constexpr
// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
constexpr auto src_slice_origin_idx = SrcSliceOriginIdx{};
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
......@@ -171,12 +181,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
vector_type<DstData, DstScalarPerVector> dst_vector;
// this is hardcoded for src that has compile-time tensor descriptor
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
// assume src_slice_origin_idx is 0
// TODO: support non-zero src_slice_oring_idx
constexpr index_t src_offset =
src_desc.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);
src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx +
i * dst_scalar_step_in_vector);
dst_vector(i) = p_src[Number<src_offset>{}];
});
......@@ -370,17 +378,27 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
}
template <typename SrcIteratorHacks>
template <typename DstSliceOriginIdx, typename SrcIteratorHacks>
__device__ void Run(const SrcDesc& src_desc,
const SrcData* p_src,
const DstDesc&,
const DstSliceOriginIdx&,
DstData* p_dst,
const SrcIteratorHacks& src_iterator_hacks)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! DstDesc need to known at compile-time");
// Comments: dst_desc is constexpr
static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<DstSliceOriginIdx>>>::value,
"wrong! DstSliceOrigin need to known at compile-time");
// DstDesc and dst_slice_origin_idx are known at compile-time
constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
......@@ -493,12 +511,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
: src_vector_t{0};
}
// this is hardcoded for dst that has compile-time tensor descriptor
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
// assume dst_slice_origin_idx is 0
// TODO: support non-zero dst_slice_oring_idx
constexpr index_t dst_offset =
dst_desc.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector);
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
i * src_scalar_step_in_vector);
p_dst[Number<dst_offset>{}] = src_vector[i];
});
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment