Commit e8421cca authored by Chao Liu

replacing array with vector for tensor data

parent 4978c9e7
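Note: the commit message is terse. As a hedged, generic illustration of the array-to-vector idea as I read the diff (plain HIP/Clang C++; the function names and the 4-wide width below are assumptions for the example, not the repository's StaticBuffer/vector_type machinery):

// Illustration only: move four floats of per-thread tensor data.
// "Array" version: a raw scalar array, element-by-element traffic.
__device__ void copy_as_array(const float* p_src, int offset, float* p_dst)
{
    float thread_buf[4];
    for(int i = 0; i < 4; ++i)
        thread_buf[i] = p_src[offset + i];
    for(int i = 0; i < 4; ++i)
        p_dst[i] = thread_buf[i];
}

// "Vector" version: the same data viewed through a vector type, so the transfer
// becomes one vectorized load and one vectorized store (assumes suitable alignment).
typedef float float4_t __attribute__((ext_vector_type(4)));

__device__ void copy_as_vector(const float* p_src, int offset, float* p_dst)
{
    *reinterpret_cast<float4_t*>(p_dst) = *reinterpret_cast<const float4_t*>(&p_src[offset]);
}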
......@@ -364,6 +364,9 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
// MLevel1ThreadCluster, NLevel1ThreadCluster
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename BlockMatrixA,
typename BlockMatrixB,
typename ThreadMatrixC,
......@@ -375,7 +378,11 @@ template <index_t BlockSize,
index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster,
index_t ThreadGemmADataPerRead_M,
index_t ThreadGemmBDataPerRead_N>
index_t ThreadGemmBDataPerRead_N,
typename std::enable_if<BlockMatrixA::IsKnownAtCompileTime() &&
BlockMatrixB::IsKnownAtCompileTime() &&
ThreadMatrixC::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
{
struct MatrixIndex
......@@ -384,10 +391,49 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
index_t col;
};
index_t mMyThreadOffsetA;
index_t mMyThreadOffsetB;
private:
static constexpr auto a_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<0>{})));
static constexpr auto b_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<1>{})));
using AThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
FloatA,
BlockMatrixA,
decltype(a_thread_mtx_desc_),
Sequence<KPerThreadLoop, MPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_M,
AddressSpace::Generic,
AddressSpace::Vgpr,
1>;
using BThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
FloatB,
BlockMatrixB,
decltype(b_thread_mtx_desc_),
Sequence<KPerThreadLoop, NPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmBDataPerRead_N,
AddressSpace::Generic,
AddressSpace::Vgpr,
1>;
MatrixIndex c_thread_begin_mtx_idx_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
public:
__device__ BlockwiseGemm_km_kn_m0m1n0n1_v1r1()
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.row)},
b_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.col)}
{
static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
BlockMatrixB::IsKnownAtCompileTime() &&
......@@ -403,23 +449,18 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! BlockSize must equal ThreadPerLevel1Cluster");
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
"wrong! K dimension not consistent\n");
"wrong! K dimension not consistent");
constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
"wrong! Cannot evenly divide work among\n");
"wrong! Cannot evenly divide work among");
static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] &&
ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1],
"wrong! ThreadMatrixC lengths is wrong");
auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.row));
mMyThreadOffsetB = BlockMatrixB{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.col));
}
__device__ static constexpr auto GetThreadMatrixCLengths()
......@@ -456,21 +497,20 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx_desc = ThreadMatrixC{};
constexpr auto K = a_block_mtx.GetLength(I0);
constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
constexpr auto MPerThread = c_thread_mtx_desc.GetLength(I0);
constexpr auto NPerThread = c_thread_mtx_desc.GetLength(I1);
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
......@@ -484,13 +524,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
static_assert(MRepeat == 2 && NRepeat == 2,
"wrong! inline asm cannot deal with this GEMM config yet");
// thread A, B
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<MPerThread>{}));
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<NPerThread>{}));
// thread A-sub, B-sub
constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}),
......@@ -504,73 +537,44 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}),
make_tuple(Number<NPerThread>{}, Number<1>{}));
FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
FloatB p_b_thread[b_thread_mtx.GetElementSpaceSize()];
constexpr auto a_thread_copy =
ThreadwiseDynamicTensorSliceTransfer_v2<FloatA,
FloatA,
BlockMatrixA,
decltype(a_thread_mtx),
Sequence<KPerThreadLoop, MPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_M,
AddressSpace::Generic,
AddressSpace::Vgpr,
1,
true>{BlockMatrixA{}, make_tuple()};
constexpr auto b_thread_copy =
ThreadwiseDynamicTensorSliceTransfer_v2<FloatB,
FloatB,
BlockMatrixB,
decltype(b_thread_mtx),
Sequence<KPerThreadLoop, NPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmBDataPerRead_N,
AddressSpace::Generic,
AddressSpace::Vgpr,
1,
true>{};
FloatA p_a_thread[a_thread_mtx_desc_.GetElementSpaceSize()];
FloatB p_b_thread[b_thread_mtx_desc_.GetElementSpaceSize()];
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_sub_mtx),
decltype(b_thread_sub_mtx),
decltype(c_thread_sub_mtx)>{};
const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;
// read A_sub_0
a_thread_copy.Run(BlockMatrixA{},
p_a_block_off,
a_thread_mtx,
make_tuple(Number<0>{}, Number<0>{}),
p_a_thread);
a_thread_copy_.Run(BlockMatrixA{},
make_tuple(Number<0>{}, Number<0>{}),
p_a_block,
a_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<0>{}),
p_a_thread);
// read B_sub_0
b_thread_copy.Run(BlockMatrixB{},
p_b_block_off,
b_thread_mtx,
make_tuple(Number<0>{}, Number<0>{}),
p_b_thread);
b_thread_copy_.Run(BlockMatrixB{},
make_tuple(Number<0>{}, Number<0>{}),
p_b_block,
b_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<0>{}),
p_b_thread);
// read B_sub_1
b_thread_copy.Run(BlockMatrixB{},
p_b_block_off +
b_block_mtx.CalculateOffset(make_tuple(0, NPerLevel1Cluster)),
b_thread_mtx,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
p_b_thread);
b_thread_copy_.Run(BlockMatrixB{},
make_tuple(Number<0>{}, Number<NPerLevel1Cluster>{}),
p_b_block,
b_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
p_b_thread);
// read A_sub_1
a_thread_copy.Run(BlockMatrixA{},
p_a_block_off +
a_block_mtx.CalculateOffset(make_tuple(0, MPerLevel1Cluster)),
a_thread_mtx,
make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
p_a_thread);
a_thread_copy_.Run(BlockMatrixA{},
make_tuple(Number<0>{}, Number<MPerLevel1Cluster>{}),
p_a_block,
a_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
p_a_thread);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
......@@ -578,53 +582,55 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(
p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(0, NPerThreadSubC)));
// loop over rest of k
static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
// read A_sub_0
a_thread_copy.Run(BlockMatrixA{},
p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, 0)),
a_thread_mtx,
make_tuple(Number<0>{}, Number<0>{}),
p_a_thread);
a_thread_copy_.Run(BlockMatrixA{},
make_tuple(k, Number<0>{}),
p_a_block,
a_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<0>{}),
p_a_thread);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread,
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
// read B_sub_0
b_thread_copy.Run(BlockMatrixB{},
p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, 0)),
b_thread_mtx,
make_tuple(Number<0>{}, Number<0>{}),
p_b_thread);
b_thread_copy_.Run(BlockMatrixB{},
make_tuple(k, Number<0>{}),
p_b_block,
b_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<0>{}),
p_b_thread);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread +
c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
// read B_sub_1
b_thread_copy.Run(BlockMatrixB{},
p_b_block_off +
b_block_mtx.CalculateOffset(make_tuple(k, NPerLevel1Cluster)),
b_thread_mtx,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
p_b_thread);
b_thread_copy_.Run(BlockMatrixB{},
make_tuple(k, Number<NPerLevel1Cluster>{}),
p_b_block,
b_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
p_b_thread);
// read A_sub_1
a_thread_copy.Run(BlockMatrixA{},
p_a_block_off +
a_block_mtx.CalculateOffset(make_tuple(k, MPerLevel1Cluster)),
a_thread_mtx,
make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
p_a_thread);
a_thread_copy_.Run(BlockMatrixA{},
make_tuple(k, Number<MPerLevel1Cluster>{}),
p_a_block,
a_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
p_a_thread);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
......@@ -632,24 +638,24 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(
p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(0, NPerThreadSubC)));
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread,
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread +
c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
}
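A hedged, host-side sketch of the 2x2 pipelined schedule implemented above: scalar values stand in for the MPerThreadSubC x NPerThreadSubC sub-tiles, plain reads stand in for a_thread_copy_/b_thread_copy_, and KPerThreadLoop is taken as 1 for brevity. It shows how the loads for k-slice k overlap the C_sub_10/C_sub_11 math that still consumes slice k-1.

#include <cstdio>

int main()
{
    constexpr int K = 4;
    float a[K][2], b[K][2]; // a[k][i], b[k][j]: the two A / B sub-tiles of k-slice k
    for(int k = 0; k < K; ++k)
        for(int i = 0; i < 2; ++i)
        {
            a[k][i] = 1.0f + k + 0.5f * i;
            b[k][i] = 2.0f - 0.25f * k + i;
        }

    float c[2][2] = {{0, 0}, {0, 0}};

    // prologue: read A_sub_0, B_sub_0, B_sub_1, A_sub_1 of slice 0, then C_sub_00 / C_sub_01
    float A0 = a[0][0], B0 = b[0][0], B1 = b[0][1], A1 = a[0][1];
    c[0][0] += A0 * B0;
    c[0][1] += A0 * B1;

    for(int k = 1; k < K; ++k)
    {
        float nA0 = a[k][0];
        c[1][0] += A1 * B0; // C_sub_10 still consumes slice k-1 registers
        float nB0 = b[k][0];
        c[1][1] += A1 * B1; // C_sub_11, slice k-1
        float nB1 = b[k][1];
        float nA1 = a[k][1];
        A0 = nA0; B0 = nB0; B1 = nB1; A1 = nA1; // registers now hold slice k
        c[0][0] += A0 * B0;                     // C_sub_00, slice k
        c[0][1] += A0 * B1;                     // C_sub_01, slice k
    }

    // epilogue: finish slice K-1 for C_sub_10 / C_sub_11
    c[1][0] += A1 * B0;
    c[1][1] += A1 * B1;

    // reference check: c[i][j] should equal sum over k of a[k][i] * b[k][j]
    for(int i = 0; i < 2; ++i)
        for(int j = 0; j < 2; ++j)
        {
            float ref = 0;
            for(int k = 0; k < K; ++k)
                ref += a[k][i] * b[k][j];
            std::printf("c[%d][%d] = %f (ref %f)\n", i, j, c[i][j], ref);
        }
}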
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
......
......@@ -255,6 +255,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
constexpr auto c_m0m1_n0n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{}));
#if 1 // debug
const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v1<BlockSize,
decltype(a_k_m_block_desc),
......@@ -269,6 +270,26 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
NLevel1Cluster,
MPerThread,
NPerThread>{};
#else
const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v1r1<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k_m_block_desc),
decltype(b_k_n_block_desc),
decltype(c_m0m1_n0n1_thread_desc),
MPerThread,
NPerThread,
KPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
MPerThread,
NPerThread>{};
#endif
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
......
......@@ -1330,6 +1330,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
// 2. a reference dst_reference_idx is given at compile-time, dst_slice_origin_idx has a
// compile-time distance to dst_reference_idx
// 3. use direct address calculation (lower of coordinate)
// 4. vector access on src
template <
typename SrcData,
typename DstData,
......@@ -1355,34 +1356,98 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4(const Index& src_ref_idx)
: src_ref_idx_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx))
: src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx))
{
static_assert(SrcDesc::IsKnownAtCompileTime && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time");
}
template <typename SrcOriginToRefDistance, typename DstOriginToRefDistance>
__device__ void Run(const SrcDesc& const SrcOriginToRefDistance& const SrcData* p_src,
template <typename SrcRefToOriginDisplacement, typename DstRefToOriginDisplacement>
__device__ void Run(const SrcDesc&,
const SrcRefToOriginDisplacement&,
const SrcData* p_src,
const DstDesc&,
const DstOriginToRefDistance& DstData* p_dst)
const DstRefToOriginDisplacement&,
DstData* p_dst) const
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time");
static_assert(is_known_at_compile_time<
remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
is_known_at_compile_time<
remove_cv_t<remove_reference_t<DstRefToOriginDisplacement>>>::value,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time");
// SrcDesc and DstDesc are known at compile-time
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
// SrcRefToOriginDisplacement and DstRefToOriginDisplacement are known at compile-time
constexpr auto src_ref_to_origin_disp_idx = SrcRefToOriginDisplacement{};
constexpr auto dst_ref_to_origin_disp_idx = DstRefToOriginDisplacement{};
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// scalar per access of each dim
constexpr auto src_scalar_per_access = generate_sequence_v2(
[&](auto i) constexpr {
if constexpr(i == SrcVectorDim)
{
return Number<SrcScalarPerVector>{};
}
else
{
return Number<1>{};
}
},
Number<nDim>{});
// scalar step (if stepping on SrcVectorDim) of each dim
constexpr auto src_scalar_step_in_vector = generate_sequence_v2(
[&](auto i) constexpr {
if constexpr(i == SrcVectorDim)
{
return Number<1>{};
}
else
{
return Number<0>{};
}
},
Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto dim_access_order = DimAccessOrder{};
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// position in slice window
constexpr auto data_to_origin_dist_idx =
// position in slice window
#if 0 // debug
constexpr auto data_to_origin_disp_idx =
container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
src_scalar_per_access;
#else
constexpr auto data_to_origin_disp_idx =
ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
#endif
// src coordinate
constexpr auto src_data_to_ref_dist_idx =
SrcOriginToRefDistance{} + data_to_origin_dist_idx;
constexpr auto src_ref_to_data_disp_idx =
to_multi_index(src_ref_to_origin_disp_idx + data_to_origin_disp_idx);
constexpr auto src_data_to_ref_dist_coord_iterator =
make_dynamic_tensor_coordinate_iterator(SrcDesc{}, src_data_to_ref_dist);
constexpr auto src_ref_to_data_disp_coord_iterator =
make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx);
auto src_data_coord = src_ref_coord_;
move_dynamic_tensor_coordinate(
src_data_coord, src_data_coord, src_data_to_ref_coord_iterator);
src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);
// copy data from src into buffer
StaticBuffer<SrcData, SrcScalarPerVector> src_buf;
......@@ -1391,20 +1456,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_coord_);
src_desc, src_data_coord);
src_buf.template AsType<src_vector_t>()(Number<0>{}) =
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
is_src_valid
? *reinterpret_cast<const src_vector_t*>(&p_src[src_data_coord.GetOffset()])
: src_vector_t{0};
// copy data from buffer into dst
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr auto src_scalar_step_in_vector = generate_sequence(
detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
constexpr index_t dst_offset = dst_desc.CalculateOffset(
to_multi_index(DstOriginToRefDistance{}) + data_to_origin_dist_idx +
to_multi_index(dst_ref_to_origin_disp_idx) + data_to_origin_disp_idx +
i * src_scalar_step_in_vector);
p_dst[Number<dst_offset>{}] = src_buf.template AsType<SrcData>()[i];
......@@ -1413,7 +1475,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
}
private:
SrcCoord src_ref_idx_;
SrcCoord src_ref_coord_;
};
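A hedged, self-contained illustration of the design notes above (reference coordinate fixed at construction, compile-time displacement per Run call, direct offset calculation). This is plain C++, not the ThreadwiseDynamicTensorSliceTransfer_v4 interface; every name in it is invented for the example.

// The runtime part of the source address (the thread's reference offset) is paid once
// in the constructor; each Run() adds only a displacement known at compile time, so the
// compiler folds the whole offset calculation into an immediate.
template <int Stride>
struct RefPlusDispCopy
{
    explicit constexpr RefPlusDispCopy(int ref_offset) : ref_offset_(ref_offset) {}

    // RowDisp / ColDisp: compile-time displacement from the reference coordinate.
    template <int RowDisp, int ColDisp, int N>
    void Run(const float* p_src, float (&dst)[N]) const
    {
        constexpr int disp = RowDisp * Stride + ColDisp; // direct, compile-time offset
        for(int i = 0; i < N; ++i)
            dst[i] = p_src[ref_offset_ + disp + i];
    }

    int ref_offset_; // runtime reference offset, computed once per thread
};

// usage sketch: RefPlusDispCopy<16> copy{row * 16};  float buf[4];  copy.Run<0, 4>(p_src, buf);

This mirrors how a_thread_copy_ / b_thread_copy_ are constructed with the thread's row/column offset and then invoked with Number<> displacements in the blockwise GEMM hunk above.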
} // namespace ck
......
......@@ -5,11 +5,26 @@
namespace ck {
template <index_t... Is>
__host__ __device__ constexpr auto make_sequence(Number<Is>...)
{
return Sequence<Is...>{};
}
// F returns index_t
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_sequence(F, Number<N>)
{
return typename sequence_gen<N, F>::type{};
}
// F returns Number<>
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number<N>)
{
return unpack([&f](auto&&... xs) { return make_sequence(f(xs)...); },
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
} // namespace ck
#endif
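A hedged usage sketch of generate_sequence_v2, mirroring the src_scalar_per_access construction in the transfer hunk above; the concrete values (nDim = 3, SrcVectorDim = 1, SrcScalarPerVector = 4) are assumptions for the example.

// The functor returns Number<>, so the _v2 overload applies (generate_sequence
// expects a functor returning index_t). The result below is Sequence<1, 4, 1>.
constexpr auto scalar_per_access = generate_sequence_v2(
    [](auto i) constexpr {
        if constexpr(i == 1) // SrcVectorDim
        {
            return Number<4>{}; // SrcScalarPerVector
        }
        else
        {
            return Number<1>{};
        }
    },
    Number<3>{}); // nDim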