Commit e8421cca authored by Chao Liu

replacing array with vector for tensor data

parent 4978c9e7
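The commit message refers to moving the thread-local tensor data path from plain scalar accesses toward vector-typed accesses (see the new ThreadwiseDynamicTensorSliceTransfer_v4 below, which reads SrcScalarPerVector elements per access). As a rough, standalone illustration of that idea only (hypothetical names, host-compiled, not the library's API), compare a scalar element-by-element copy with a copy that moves one vector at a time:

    // Sketch (assumption, not ck code): "array vs. vector" access for thread data.
    #include <cstdio>
    #include <cstring>

    struct float4_t { float x[4]; }; // stand-in for a 4-wide vector type

    // scalar path: four separate loads and stores
    void copy_scalar(const float* src, float* dst)
    {
        for(int i = 0; i < 4; ++i) { dst[i] = src[i]; }
    }

    // vector path: one 16-byte load and store (memcpy keeps it well-defined on the host)
    void copy_vector(const float* src, float* dst)
    {
        float4_t v;
        std::memcpy(&v, src, sizeof(v));
        std::memcpy(dst, &v, sizeof(v));
    }

    int main()
    {
        float a[4] = {1, 2, 3, 4}, b[4] = {}, c[4] = {};
        copy_scalar(a, b);
        copy_vector(a, c);
        std::printf("%g %g %g %g / %g %g %g %g\n", b[0], b[1], b[2], b[3], c[0], c[1], c[2], c[3]);
    }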
@@ -364,6 +364,9 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
 // MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
 // MLevel1ThreadCluster, NLevel1ThreadCluster
 template <index_t BlockSize,
+          typename FloatA,
+          typename FloatB,
+          typename FloatC,
           typename BlockMatrixA,
           typename BlockMatrixB,
           typename ThreadMatrixC,
@@ -375,7 +378,11 @@ template <index_t BlockSize,
           index_t MLevel1ThreadCluster,
           index_t NLevel1ThreadCluster,
           index_t ThreadGemmADataPerRead_M,
-          index_t ThreadGemmBDataPerRead_N>
+          index_t ThreadGemmBDataPerRead_N,
+          typename std::enable_if<BlockMatrixA::IsKnownAtCompileTime() &&
+                                      BlockMatrixB::IsKnownAtCompileTime() &&
+                                      ThreadMatrixC::IsKnownAtCompileTime(),
+                                  bool>::type = false>
 struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
 {
     struct MatrixIndex
@@ -384,10 +391,49 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         index_t col;
     };

-    index_t mMyThreadOffsetA;
-    index_t mMyThreadOffsetB;
+    private:
+    static constexpr auto a_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<0>{})));
+
+    static constexpr auto b_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<1>{})));
+
+    using AThreadCopy =
+        ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
+                                                FloatA,
+                                                BlockMatrixA,
+                                                decltype(a_thread_mtx_desc_),
+                                                Sequence<KPerThreadLoop, MPerThreadSubC>,
+                                                Sequence<0, 1>,
+                                                1,
+                                                ThreadGemmADataPerRead_M,
+                                                AddressSpace::Generic,
+                                                AddressSpace::Vgpr,
+                                                1>;
+
+    using BThreadCopy =
+        ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
+                                                FloatB,
+                                                BlockMatrixB,
+                                                decltype(b_thread_mtx_desc_),
+                                                Sequence<KPerThreadLoop, NPerThreadSubC>,
+                                                Sequence<0, 1>,
+                                                1,
+                                                ThreadGemmBDataPerRead_N,
+                                                AddressSpace::Generic,
+                                                AddressSpace::Vgpr,
+                                                1>;
+
+    MatrixIndex c_thread_begin_mtx_idx_;
+
+    AThreadCopy a_thread_copy_;
+    BThreadCopy b_thread_copy_;
+
+    public:
     __device__ BlockwiseGemm_km_kn_m0m1n0n1_v1r1()
+        : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
+          a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.row)},
+          b_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.col)}
     {
         static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
                           BlockMatrixB::IsKnownAtCompileTime() &&
@@ -403,23 +449,18 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");

         static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
-                      "wrong! K dimension not consistent\n");
+                      "wrong! K dimension not consistent");

         constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
         constexpr index_t N = BlockMatrixB{}.GetLength(I1);

         static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
                           N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
-                      "wrong! Cannot evenly divide work among\n");
+                      "wrong! Cannot evenly divide work among");

         static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] &&
                           ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1],
                       "wrong! ThreadMatrixC lengths is wrong");
-
-        auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
-
-        mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.row));
-        mMyThreadOffsetB = BlockMatrixB{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.col));
     }

     __device__ static constexpr auto GetThreadMatrixCLengths()
@@ -456,21 +497,20 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
     }

-    template <typename FloatA, typename FloatB, typename FloatC>
     __device__ void
     Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};

         constexpr auto a_block_mtx = BlockMatrixA{};
         constexpr auto b_block_mtx = BlockMatrixB{};
-        constexpr auto c_thread_mtx = ThreadMatrixC{};
+        constexpr auto c_thread_mtx_desc = ThreadMatrixC{};

         constexpr auto K = a_block_mtx.GetLength(I0);

-        constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
-        constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
+        constexpr auto MPerThread = c_thread_mtx_desc.GetLength(I0);
+        constexpr auto NPerThread = c_thread_mtx_desc.GetLength(I1);

         constexpr index_t MPerLevel1Cluster =
             MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
@@ -484,13 +524,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         static_assert(MRepeat == 2 && NRepeat == 2,
                       "wrong! inline asm cannot deal with this GEMM config yet");

-        // thread A, B
-        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<KPerThreadLoop>{}, Number<MPerThread>{}));
-
-        constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<KPerThreadLoop>{}, Number<NPerThread>{}));
-
         // thread A-sub, B-sub
         constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
             make_tuple(Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}),
@@ -504,73 +537,44 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
             make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}),
             make_tuple(Number<NPerThread>{}, Number<1>{}));

-        FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
-        FloatB p_b_thread[b_thread_mtx.GetElementSpaceSize()];
-
-        constexpr auto a_thread_copy =
-            ThreadwiseDynamicTensorSliceTransfer_v2<FloatA,
-                                                    FloatA,
-                                                    BlockMatrixA,
-                                                    decltype(a_thread_mtx),
-                                                    Sequence<KPerThreadLoop, MPerThreadSubC>,
-                                                    Sequence<0, 1>,
-                                                    1,
-                                                    ThreadGemmADataPerRead_M,
-                                                    AddressSpace::Generic,
-                                                    AddressSpace::Vgpr,
-                                                    1,
-                                                    true>{BlockMatrixA{}, make_tuple()};
-
-        constexpr auto b_thread_copy =
-            ThreadwiseDynamicTensorSliceTransfer_v2<FloatB,
-                                                    FloatB,
-                                                    BlockMatrixB,
-                                                    decltype(b_thread_mtx),
-                                                    Sequence<KPerThreadLoop, NPerThreadSubC>,
-                                                    Sequence<0, 1>,
-                                                    1,
-                                                    ThreadGemmBDataPerRead_N,
-                                                    AddressSpace::Generic,
-                                                    AddressSpace::Vgpr,
-                                                    1,
-                                                    true>{};
+        FloatA p_a_thread[a_thread_mtx_desc_.GetElementSpaceSize()];
+        FloatB p_b_thread[b_thread_mtx_desc_.GetElementSpaceSize()];

         constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_sub_mtx),
                                                                     decltype(b_thread_sub_mtx),
                                                                     decltype(c_thread_sub_mtx)>{};

-        const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
-        const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;
-
         // read A_sub_0
-        a_thread_copy.Run(BlockMatrixA{},
-                          p_a_block_off,
-                          a_thread_mtx,
-                          make_tuple(Number<0>{}, Number<0>{}),
-                          p_a_thread);
+        a_thread_copy_.Run(BlockMatrixA{},
+                           make_tuple(Number<0>{}, Number<0>{}),
+                           p_a_block,
+                           a_thread_mtx_desc_,
+                           make_tuple(Number<0>{}, Number<0>{}),
+                           p_a_thread);

         // read B_sub_0
-        b_thread_copy.Run(BlockMatrixB{},
-                          p_b_block_off,
-                          b_thread_mtx,
-                          make_tuple(Number<0>{}, Number<0>{}),
-                          p_b_thread);
+        b_thread_copy_.Run(BlockMatrixB{},
+                           make_tuple(Number<0>{}, Number<0>{}),
+                           p_b_block,
+                           b_thread_mtx_desc_,
+                           make_tuple(Number<0>{}, Number<0>{}),
+                           p_b_thread);

         // read B_sub_1
-        b_thread_copy.Run(BlockMatrixB{},
-                          p_b_block_off +
-                              b_block_mtx.CalculateOffset(make_tuple(0, NPerLevel1Cluster)),
-                          b_thread_mtx,
-                          make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
-                          p_b_thread);
+        b_thread_copy_.Run(BlockMatrixB{},
+                           make_tuple(Number<0>{}, Number<NPerLevel1Cluster>{}),
+                           p_b_block,
+                           b_thread_mtx_desc_,
+                           make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                           p_b_thread);

         // read A_sub_1
-        a_thread_copy.Run(BlockMatrixA{},
-                          p_a_block_off +
-                              a_block_mtx.CalculateOffset(make_tuple(0, MPerLevel1Cluster)),
-                          a_thread_mtx,
-                          make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
-                          p_a_thread);
+        a_thread_copy_.Run(BlockMatrixA{},
+                           make_tuple(Number<0>{}, Number<MPerLevel1Cluster>{}),
+                           p_a_block,
+                           a_thread_mtx_desc_,
+                           make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                           p_a_thread);

         // C_sub_00 += transpose(A_sub_0) * B_sub_0
         threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
@@ -578,53 +582,55 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         // C_sub_01 += transpose(A_sub_0) * B_sub_1
         threadwise_gemm.Run(
             p_a_thread,
-            p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-            p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
+            p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
+            p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(0, NPerThreadSubC)));

         // loop over rest of k
         static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
             // read A_sub_0
-            a_thread_copy.Run(BlockMatrixA{},
-                              p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, 0)),
-                              a_thread_mtx,
-                              make_tuple(Number<0>{}, Number<0>{}),
-                              p_a_thread);
+            a_thread_copy_.Run(BlockMatrixA{},
+                               make_tuple(k, Number<0>{}),
+                               p_a_block,
+                               a_thread_mtx_desc_,
+                               make_tuple(Number<0>{}, Number<0>{}),
+                               p_a_thread);

             // C_sub_10 += transpose(A_sub_1) * B_sub_0
             threadwise_gemm.Run(
-                p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
+                p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
                 p_b_thread,
-                p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
+                p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, 0)));

             // read B_sub_0
-            b_thread_copy.Run(BlockMatrixB{},
-                              p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, 0)),
-                              b_thread_mtx,
-                              make_tuple(Number<0>{}, Number<0>{}),
-                              p_b_thread);
+            b_thread_copy_.Run(BlockMatrixB{},
+                               make_tuple(k, Number<0>{}),
+                               p_b_block,
+                               b_thread_mtx_desc_,
+                               make_tuple(Number<0>{}, Number<0>{}),
+                               p_b_thread);

             // C_sub_11 += transpose(A_sub_1) * B_sub_1
             threadwise_gemm.Run(
-                p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
-                p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
+                p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
+                p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
                 p_c_thread +
-                    c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
+                    c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));

             // read B_sub_1
-            b_thread_copy.Run(BlockMatrixB{},
-                              p_b_block_off +
-                                  b_block_mtx.CalculateOffset(make_tuple(k, NPerLevel1Cluster)),
-                              b_thread_mtx,
-                              make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
-                              p_b_thread);
+            b_thread_copy_.Run(BlockMatrixB{},
+                               make_tuple(k, Number<NPerLevel1Cluster>{}),
+                               p_b_block,
+                               b_thread_mtx_desc_,
+                               make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                               p_b_thread);

             // read A_sub_1
-            a_thread_copy.Run(BlockMatrixA{},
-                              p_a_block_off +
-                                  a_block_mtx.CalculateOffset(make_tuple(k, MPerLevel1Cluster)),
-                              a_thread_mtx,
-                              make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
-                              p_a_thread);
+            a_thread_copy_.Run(BlockMatrixA{},
+                               make_tuple(k, Number<MPerLevel1Cluster>{}),
+                               p_a_block,
+                               a_thread_mtx_desc_,
+                               make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                               p_a_thread);

             // C_sub_00 += transpose(A_sub_0) * B_sub_0
             threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
@@ -632,24 +638,24 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
             // C_sub_01 += transpose(A_sub_0) * B_sub_1
             threadwise_gemm.Run(
                 p_a_thread,
-                p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-                p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
+                p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
+                p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(0, NPerThreadSubC)));
         });

         // C_sub_10 += transpose(A_sub_1) * B_sub_0
         threadwise_gemm.Run(
-            p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
+            p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
             p_b_thread,
-            p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
+            p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, 0)));

         // C_sub_11 += transpose(A_sub_1) * B_sub_1
         threadwise_gemm.Run(
-            p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
-            p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-            p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
+            p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
+            p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
+            p_c_thread +
+                c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
     }

-    template <typename FloatA, typename FloatB, typename FloatC>
     __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
     {
 #if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
...
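In the hunks above, the per-thread A/B copies change from local ThreadwiseDynamicTensorSliceTransfer_v2 objects plus precomputed element offsets (mMyThreadOffsetA/B added to the block pointer) to member ThreadwiseDynamicTensorSliceTransfer_v4 objects constructed once from the thread's C-tile origin; Run() then receives the remaining displacement (e.g. make_tuple(k, Number<NPerLevel1Cluster>{})) instead of a pre-offset pointer. A minimal host-side sketch of that pattern, with all names hypothetical:

    #include <cstdio>

    struct RowMajor2D
    {
        int rows, cols;
        int Offset(int r, int c) const { return r * cols + c; }
    };

    // Analogue of the new member copy object: it remembers the thread's origin
    // (like c_thread_begin_mtx_idx_) and Run() takes a displacement relative to it.
    struct ThreadCopy
    {
        RowMajor2D desc;
        int origin_row, origin_col;

        void Run(const float* src, int dr, int dc, float* dst, int n) const
        {
            const float* p = src + desc.Offset(origin_row + dr, origin_col + dc);
            for(int i = 0; i < n; ++i) { dst[i] = p[i]; }
        }
    };

    int main()
    {
        float block[8 * 8];
        for(int i = 0; i < 64; ++i) { block[i] = float(i); }

        float thread_buf[4];
        ThreadCopy copy{{8, 8}, /*origin_row=*/0, /*origin_col=*/2}; // origin fixed at construction
        copy.Run(block, /*dr=*/1, /*dc=*/0, thread_buf, 4);          // displacement passed per call
        std::printf("%g %g %g %g\n", thread_buf[0], thread_buf[1], thread_buf[2], thread_buf[3]);
    }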
@@ -255,6 +255,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         constexpr auto c_m0m1_n0n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
             make_tuple(Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{}));

+#if 1 // debug
         const auto blockwise_gemm =
             BlockwiseGemm_km_kn_m0m1n0n1_v1<BlockSize,
                                             decltype(a_k_m_block_desc),
@@ -269,6 +270,26 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                                             NLevel1Cluster,
                                             MPerThread,
                                             NPerThread>{};
+#else
+        const auto blockwise_gemm =
+            BlockwiseGemm_km_kn_m0m1n0n1_v1r1<BlockSize,
+                                              FloatAB,
+                                              FloatAB,
+                                              FloatAcc,
+                                              decltype(a_k_m_block_desc),
+                                              decltype(b_k_n_block_desc),
+                                              decltype(c_m0m1_n0n1_thread_desc),
+                                              MPerThread,
+                                              NPerThread,
+                                              KPerThread,
+                                              MLevel0Cluster,
+                                              NLevel0Cluster,
+                                              MLevel1Cluster,
+                                              NLevel1Cluster,
+                                              MPerThread,
+                                              NPerThread>{};
+#endif

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
...
@@ -1330,6 +1330,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
 // 2. a reference dst_reference_idx is given at compile-time, dst_slice_origin_idx has a
 //    compile-time distance to dst_reference_idx
 // 3. use direct address calculation (lower of coordinate)
+// 3. vector access on src
 template <
     typename SrcData,
     typename DstData,
@@ -1355,34 +1356,98 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
     using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));

     __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4(const Index& src_ref_idx)
-        : src_ref_idx_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx))
+        : src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx))
     {
-        static_assert(SrcDesc::IsKnownAtCompileTime && DstDesc::IsKnownAtCompileTime(),
-                      "wrong! SrcDesc need to known at compile-time");
+        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+                      "wrong! SrcDesc and DstDesc need to known at compile-time");
     }

-    template <typename SrcOriginToRefDistance, typename DstOriginToRefDistance>
-    __device__ void Run(const SrcDesc&,
-                        const SrcOriginToRefDistance&,
-                        const SrcData* p_src,
-                        const DstDesc&,
-                        const DstOriginToRefDistance&,
-                        DstData* p_dst) const
+    template <typename SrcRefToOriginDisplacement, typename DstRefToOriginDisplacement>
+    __device__ void Run(const SrcDesc&,
+                        const SrcRefToOriginDisplacement&,
+                        const SrcData* p_src,
+                        const DstDesc&,
+                        const DstRefToOriginDisplacement&,
+                        DstData* p_dst) const
     {
+        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+                      "wrong! SrcDesc and DstDesc need to known at compile-time");
+
+        static_assert(is_known_at_compile_time<
+                          remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
+                          is_known_at_compile_time<
+                              remove_cv_t<remove_reference_t<DstRefToOriginDisplacement>>>::value,
+                      "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
+                      "at compile-time");
+
+        // SrcDesc and DstDesc are known at compile-time
+        constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
+        constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
+
+        // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
+        constexpr auto src_ref_to_origin_disp_idx = SrcRefToOriginDisplacement{};
+        constexpr auto dst_ref_to_origin_disp_idx = DstRefToOriginDisplacement{};
+
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+
+        // scalar per access of each dim
+        constexpr auto src_scalar_per_access = generate_sequence_v2(
+            [&](auto i) constexpr {
+                if constexpr(i == SrcVectorDim)
+                {
+                    return Number<SrcScalarPerVector>{};
+                }
+                else
+                {
+                    return Number<1>{};
+                }
+            },
+            Number<nDim>{});
+
+        // scalar step (if steping on SrcVectorDim) of each dim
+        constexpr auto src_scalar_step_in_vector = generate_sequence_v2(
+            [&](auto i) constexpr {
+                if constexpr(i == SrcVectorDim)
+                {
+                    return Number<1>{};
+                }
+                else
+                {
+                    return Number<0>{};
+                }
+            },
+            Number<nDim>{});
+
+        constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
+
+        constexpr auto dim_access_order = DimAccessOrder{};
+
+        constexpr auto ordered_access_lengths =
+            container_reorder_given_new2old(access_lengths, dim_access_order);
+
         static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
             // position in slice window
-            constexpr auto data_to_origin_dist_idx =
+#if 0 // debug
+            constexpr auto data_to_origin_disp_idx =
                 container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
                 src_scalar_per_access;
+#else
+            constexpr auto data_to_origin_disp_idx =
+                ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
+#endif

             // src coordinate
-            constexpr auto src_data_to_ref_dist_idx =
-                SrcOriginToRefDistance{} + data_to_origin_dist_idx;
+            constexpr auto src_ref_to_data_disp_idx =
+                to_multi_index(src_ref_to_origin_disp_idx + data_to_origin_disp_idx);

-            constexpr auto src_data_to_ref_dist_coord_iterator =
-                make_dynamic_tensor_coordinate_iterator(SrcDesc{}, src_data_to_ref_dist);
+            constexpr auto src_ref_to_data_disp_coord_iterator =
+                make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx);

             auto src_data_coord = src_ref_coord_;

             move_dynamic_tensor_coordinate(
-                src_data_coord, src_data_to_ref_coord_iterator);
+                src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);

             // copy data from src into buffer
             StaticBuffer<SrcData, SrcScalarPerVector> src_buf;
@@ -1391,20 +1456,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
                 typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;

             const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
-                src_desc, src_slice_origin_coord_);
+                src_desc, src_data_coord);

             src_buf.template AsType<src_vector_t>()(Number<0>{}) =
-                is_src_valid ? *reinterpret_cast<const src_vector_t*>(
-                                   &p_src[src_slice_origin_coord_.GetOffset()])
-                             : src_vector_t{0};
+                is_src_valid
+                    ? *reinterpret_cast<const src_vector_t*>(&p_src[src_data_coord.GetOffset()])
+                    : src_vector_t{0};

             // copy data from buffer into dst
             static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
-                constexpr auto src_scalar_step_in_vector = generate_sequence(
-                    detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
-
                 constexpr index_t dst_offset = dst_desc.CalculateOffset(
-                    to_multi_index(DstOriginToRefDistance{}) + data_to_origin_dist_idx +
+                    to_multi_index(dst_ref_to_origin_disp_idx) + data_to_origin_disp_idx +
                     i * src_scalar_step_in_vector);

                 p_dst[Number<dst_offset>{}] = src_buf.template AsType<SrcData>()[i];
@@ -1413,7 +1475,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
     }

     private:
-    SrcCoord src_ref_idx_;
+    SrcCoord src_ref_coord_;
 };

 } // namespace ck
...
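The new Run() above walks the slice window and, for each access, forms the coordinate as "stored reference + compile-time displacement", performs one vector-wide read of SrcScalarPerVector elements (zero-filled when the source coordinate is not valid), and then scatters the scalars to destination offsets computed at compile time. A simplified, host-compilable sketch of that access pattern (assumed names, not the real ThreadwiseDynamicTensorSliceTransfer_v4):

    #include <cstdio>
    #include <cstring>

    constexpr int kVector = 4; // stands in for SrcScalarPerVector

    void transfer_row(const float* src, int src_len, int ref, int disp, float* dst)
    {
        float buf[kVector] = {0, 0, 0, 0};       // src_buf, zero by default (invalid reads stay 0)
        const int off = ref + disp;              // reference coordinate + displacement
        const bool is_src_valid = off >= 0 && off + kVector <= src_len;

        if(is_src_valid)
        {
            std::memcpy(buf, src + off, sizeof(buf)); // one vector-wide load
        }

        for(int i = 0; i < kVector; ++i)         // per-scalar writes into dst
        {
            dst[i] = buf[i];
        }
    }

    int main()
    {
        float src[8] = {0, 1, 2, 3, 4, 5, 6, 7};
        float dst[kVector];
        transfer_row(src, 8, /*ref=*/2, /*disp=*/2, dst);
        std::printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]); // 4 5 6 7
    }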
@@ -5,11 +5,26 @@
 namespace ck {

+template <index_t... Is>
+__host__ __device__ constexpr auto make_sequence(Number<Is>...)
+{
+    return Sequence<Is...>{};
+}
+
+// F returns index_t
 template <typename F, index_t N>
 __host__ __device__ constexpr auto generate_sequence(F, Number<N>)
 {
     return typename sequence_gen<N, F>::type{};
 }

+// F returns Number<>
+template <typename F, index_t N>
+__host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number<N>)
+{
+    return unpack([&f](auto&&... xs) { return make_sequence(f(xs)...); },
+                  typename arithmetic_sequence_gen<0, N, 1>::type{});
+}
+
 } // namespace ck
 #endif
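generate_sequence_v2 differs from generate_sequence in that the callable returns a Number<>, so the generated value is carried in the return type and can be packed directly into a Sequence by make_sequence; that is what lets the transfer's Run() above build src_scalar_per_access from an if-constexpr lambda. A standalone sketch of the same trick rebuilt on std::integer_sequence (an illustration of the idea, not the ck implementation):

    #include <cstddef>
    #include <cstdio>
    #include <utility>

    // make_sequence analogue: the values live in the argument *types*
    template <int... Is>
    constexpr auto make_sequence(std::integral_constant<int, Is>...)
    {
        return std::integer_sequence<int, Is...>{};
    }

    template <typename F, std::size_t... Is>
    constexpr auto generate_sequence_v2_impl(F&& f, std::index_sequence<Is...>)
    {
        // like unpack([&f](auto... xs) { return make_sequence(f(xs)...); }, 0 .. N-1)
        return make_sequence(f(std::integral_constant<std::size_t, Is>{})...);
    }

    template <std::size_t N, typename F>
    constexpr auto generate_sequence_v2(F&& f)
    {
        return generate_sequence_v2_impl(std::forward<F>(f), std::make_index_sequence<N>{});
    }

    template <int... Is>
    void print(std::integer_sequence<int, Is...>)
    {
        ((std::printf("%d ", Is)), ...);
        std::printf("\n");
    }

    int main()
    {
        // mirrors src_scalar_per_access: the vector width on the vector dim, 1 elsewhere
        auto scalar_per_access = generate_sequence_v2<3>([](auto i) {
            if constexpr(decltype(i)::value == 1)        // assume dim 1 is the vector dim
                return std::integral_constant<int, 4>{}; // assume 4 scalars per vector
            else
                return std::integral_constant<int, 1>{};
        });

        print(scalar_per_access); // prints: 1 4 1
    }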