Commit c9b52bce authored by Chao Liu

threadwise copy v1r3: take (compile-time) SrcSliceOriginIdx as argument

parent dea22d0e
@@ -417,7 +417,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
                     m_thread_data_on_global % M1,
                     n_thread_data_on_global / N1,
                     n_thread_data_on_global % N1))
-            .Run(p_c_thread,
+            .Run(c_m0_m1_n0_n1_thread_desc,
+                 make_tuple(I0, I0, I0, I0),
+                 p_c_thread,
                  c_m0_m1_n0_n1_global_desc,
                  p_c_global,
                  c_m0_m1_n0_n1_global_tensor_iterator_hacks);
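The call-site change above illustrates the new convention: Run() now receives the source-side descriptor (c_m0_m1_n0_n1_thread_desc) and a compile-time slice origin (make_tuple(I0, I0, I0, I0)) ahead of the data pointers. A minimal stand-alone sketch of why such a tuple can serve as a compile-time argument; the simplified Number below is an assumption for illustration, not the library's definition:

    #include <tuple>

    // Simplified stand-in for the library's Number<N>: an empty type that
    // carries its value in the type itself, so objects hold no runtime state.
    template <int N>
    struct Number
    {
        static constexpr int value = N;
    };

    int main()
    {
        constexpr auto I0 = Number<0>{};

        // The tuple's type, std::tuple<Number<0>, Number<0>, Number<0>, Number<0>>,
        // encodes all four indices; Run() can default-construct a
        // SrcSliceOriginIdx{} and recover them without any runtime data.
        constexpr auto origin = std::make_tuple(I0, I0, I0, I0);

        static_assert(std::tuple_element_t<0, decltype(origin)>::value == 0,
                      "index recovered from the type alone");
        return 0;
    }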
@@ -69,18 +69,28 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
         dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
     }
 
-    template <typename DstIteratorHacks>
-    __device__ void Run(const SrcData* p_src,
+    template <typename SrcSliceOriginIdx, typename DstIteratorHacks>
+    __device__ void Run(const SrcDesc&,
+                        const SrcSliceOriginIdx&,
+                        const SrcData* p_src,
                         const DstDesc& dst_desc,
                         DstData* p_dst,
                         const DstIteratorHacks& dst_iterator_hacks)
     {
+        static_assert(SrcDesc::IsKnownAtCompileTime(),
+                      "wrong! SrcDesc needs to be known at compile-time");
+
+        static_assert(
+            is_known_at_compile_time<remove_cv_t<remove_reference_t<SrcSliceOriginIdx>>>::value,
+            "wrong! SrcSliceOrigin needs to be known at compile-time");
+
+        // SrcDesc and src_slice_origin_idx are known at compile-time
+        constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
+        constexpr auto src_slice_origin_idx = SrcSliceOriginIdx{};
+
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
 
-        // Comments: src_desc is constexpr
-        constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
-
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto dst_scalar_per_access = generate_sequence(
@@ -171,12 +181,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
         vector_type<DstData, DstScalarPerVector> dst_vector;
 
-        // this is hardcoded for src that has compile-time tensor descriptor
         static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
-            // assume src_slice_origin_idx is 0
-            // TODO: support non-zero src_slice_origin_idx
             constexpr index_t src_offset =
-                src_desc.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);
+                src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx +
+                                         i * dst_scalar_step_in_vector);
 
             dst_vector(i) = p_src[Number<src_offset>{}];
         });
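Net effect of the v1r3 change: because SrcDesc and SrcSliceOriginIdx are both empty compile-time types, the source offset (slice origin + data index + step within the vector) folds into a single constant, and the load stays p_src[Number<src_offset>{}]. A self-contained sketch of that pattern, with a hypothetical one-dimensional StaticDesc and std::integral_constant standing in for the library's descriptor and Number types:

    #include <type_traits>

    using index_t = int;

    // Hypothetical 1-D descriptor whose stride is baked into the type, so it
    // is default-constructible and usable inside a constant expression.
    template <index_t Stride>
    struct StaticDesc
    {
        static constexpr bool IsKnownAtCompileTime() { return true; }
        constexpr index_t CalculateOffset(index_t idx) const { return Stride * idx; }
    };

    // Mirrors the shape of the new Run(): descriptor and origin arrive as
    // types, get default-constructed, and the offset is a constant expression.
    template <typename Desc, typename OriginIdx>
    constexpr index_t src_offset_of(index_t data_idx, index_t vector_i)
    {
        static_assert(Desc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc needs to be known at compile-time");

        constexpr auto desc = Desc{};

        // origin + data index + vector step, as in
        // CalculateOffset(to_multi_index(src_slice_origin_idx) + ...)
        return desc.CalculateOffset(OriginIdx::value + data_idx + vector_i);
    }

    int main()
    {
        // stride 4, origin 2, data index 1, vector lane 0 -> offset 4 * 3 = 12
        constexpr index_t off =
            src_offset_of<StaticDesc<4>, std::integral_constant<index_t, 2>>(1, 0);
        static_assert(off == 12, "offset folded at compile time");
        return 0;
    }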
@@ -370,18 +378,28 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
         src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
     }
 
-    template <typename SrcIteratorHacks>
+    template <typename DstSliceOriginIdx, typename SrcIteratorHacks>
     __device__ void Run(const SrcDesc& src_desc,
                         const SrcData* p_src,
+                        const DstDesc&,
+                        const DstSliceOriginIdx&,
                         DstData* p_dst,
                         const SrcIteratorHacks& src_iterator_hacks)
     {
+        static_assert(DstDesc::IsKnownAtCompileTime(),
+                      "wrong! DstDesc needs to be known at compile-time");
+
+        static_assert(
+            is_known_at_compile_time<remove_cv_t<remove_reference_t<DstSliceOriginIdx>>>::value,
+            "wrong! DstSliceOrigin needs to be known at compile-time");
+
+        // DstDesc and dst_slice_origin_idx are known at compile-time
+        constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
+        constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};
+
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
 
-        // Comments: dst_desc is constexpr
-        constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
-
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto src_scalar_per_access = generate_sequence(
@@ -493,12 +511,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
                                : src_vector_t{0};
         }
 
-        // this is hardcoded for dst that has compile-time tensor descriptor
         static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
-            // assume dst_slice_origin_idx is 0
-            // TODO: support non-zero dst_slice_origin_idx
             constexpr index_t dst_offset =
-                dst_desc.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector);
+                dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
+                                         i * src_scalar_step_in_vector);
 
             p_dst[Number<dst_offset>{}] = src_vector[i];
         });
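The v2 change mirrors v1r3 with the roles swapped: there the destination is the compile-time side, so its descriptor and slice origin are passed as empty tag arguments while the source stays dynamic. A hypothetical call site under the new signature; every name below is assumed for illustration and does not appear in this commit:

    // global -> thread copy: dst (thread) side is known at compile time
    threadwise_copy_v2.Run(a_k_m_global_desc,   // runtime src descriptor
                           p_a_global,          // runtime src data
                           a_k_m_thread_desc,   // compile-time dst descriptor (tag only)
                           make_tuple(I0, I0),  // compile-time dst slice origin
                           p_a_thread,
                           a_k_m_global_tensor_iterator_hacks);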