Commit e38c1b73 authored by Chao Liu

replacing array with vector for tensor data

parent 841b1480
@@ -130,13 +130,13 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
         // thread A, B for GEMM
-        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            Number<KPerThreadLoop>{}, Number<MPerThread>{});
+        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
+            make_tuple(Number<KPerThreadLoop>{}, Number<MPerThread>{}));

-        constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            Number<KPerThreadLoop>{}, Number<NPerThread>{});
+        constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
+            make_tuple(Number<KPerThreadLoop>{}, Number<NPerThread>{}));

-        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
-        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
+        FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
+        FloatB p_b_thread[b_thread_mtx.GetElementSpaceSize()];

         constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixA,
                                                                     decltype(a_thread_mtx),
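The GetElementSpaceSize() call introduced above returns the number of elements the thread buffer must hold. For a packed descriptor built by make_dynamic_naive_tensor_descriptor_packed_v2 this is just the product of the dimension lengths, since a packed layout leaves no gaps. A minimal sketch of that idea with a hypothetical PackedDescriptor type, not the library's actual implementation:

#include <cstddef>

// Hypothetical stand-in for a packed naive tensor descriptor: the element
// space it occupies is the product of its compile-time lengths, because a
// packed layout stores elements contiguously with no padding.
template <std::size_t... Lengths>
struct PackedDescriptor
{
    static constexpr std::size_t GetElementSpaceSize() { return (Lengths * ...); }
};

// e.g. a KPerThreadLoop x MPerThread tile of 2 x 4 needs a buffer of 8 elements
static_assert(PackedDescriptor<2, 4>::GetElementSpaceSize() == 8, "packed size");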
@@ -153,37 +153,31 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
         constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_mtx),
                                                                     decltype(b_thread_mtx),
                                                                     decltype(c_thread_mtx)>{};

-#pragma unroll
         // loop over k
-        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
-        {
-#pragma unroll
+        static_for<0, K, KPerThreadLoop>{}([&](auto k_begin) {
             // read A
-            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
-            {
+            static_for<0, MRepeat, 1>{}([&](auto m_repeat) {
                 a_thread_copy.Run(p_a_block +
                                       a_block_mtx.CalculateOffset(
                                           make_tuple(k_begin, m_repeat * MPerLevel1Cluster)) +
                                       mMyThreadOffsetA,
                                   p_a_thread + a_thread_mtx.CalculateOffset(
                                                    make_tuple(0, m_repeat * MPerThreadSubC)));
-            }
+            });

-#pragma unroll
             // read B
-            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
-            {
+            static_for<0, NRepeat, 1>{}([&](auto n_repeat) {
                 b_thread_copy.Run(p_b_block +
                                       b_block_mtx.CalculateOffset(
                                           make_tuple(k_begin, n_repeat * NPerLevel1Cluster)) +
                                       mMyThreadOffsetB,
                                   p_b_thread + b_thread_mtx.CalculateOffset(
                                                    make_tuple(0, n_repeat * NPerThreadSubC)));
-            }
+            });

             // C += A * B
             threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
-        }
+        });
     }

     template <typename FloatA, typename FloatB, typename FloatC>
...
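The main change in this file replaces the #pragma unroll for loops with static_for, so the loop indices (k_begin, m_repeat, n_repeat) become compile-time constants that can flow into the constexpr CalculateOffset calls. Below is a minimal sketch of the static_for pattern, assuming a C++17 fold-expression implementation; it is illustrative only, not the library's actual definition:

#include <utility>

// Compile-time counted loop: invokes f with std::integral_constant indices
// Begin, Begin + Step, ... (< End), so the body is fully unrolled and each
// index can be used where a constant expression is required.
template <int Begin, int End, int Step>
struct static_for
{
    template <typename F>
    constexpr void operator()(F f) const
    {
        constexpr int num_iters = (End - Begin + Step - 1) / Step;
        unroll(f, std::make_integer_sequence<int, num_iters>{});
    }

    private:
    template <typename F, int... Is>
    static constexpr void unroll(F f, std::integer_sequence<int, Is...>)
    {
        // each index is passed as an integral_constant, usable in constexpr context
        (f(std::integral_constant<int, Begin + Is * Step>{}), ...);
    }
};

// usage mirroring the hunk above:
//   static_for<0, K, KPerThreadLoop>{}([&](auto k_begin) { /* body */ });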
@@ -27,10 +27,14 @@ struct lambda_scalar_step_in_vector
     }
 };

+// Assume:
+// 1. src_desc is known at compile-time
+// 2. dst_desc is not known at compile-time
+// 3. src_slice_origin_idx is known at compile-time and it's 0
+// 4. dst_slice_origin_idx is not known at compile-time
 // this version is less likely to have scratch memory issue, due to:
 // 1. It does not keep reference to tensor descriptor
 // 2. It does not construct new tensor coordinate for this->Run()
-// Assume src_slice_origin_idx is 0
 // TODO: support non-zero src_slice_origin_idx
 template <typename SrcData,
           typename DstData,
@@ -359,10 +363,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
     DstCoord dst_slice_origin_coord_;
 }; // namespace ck

+// Assume:
+// 1. src_desc is not known at compile-time
+// 2. dst_desc is known at compile-time
+// 3. src_slice_origin_idx is not known at compile-time
+// 4. dst_slice_origin_idx is known at compile-time and it's 0
 // this version is less likely to have scratch memory issue, due to:
 // 1. It does not keep reference to tensor descriptor
 // 2. It does not construct new tensor coordinate for this->Run()
-// Assume dst_slice_origin_idx is 0
 template <typename SrcData,
           typename DstData,
           typename SrcDesc,
@@ -590,7 +598,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
         }
     }

-    __device__ void Run(const SrcDesc& src_desc, const SrcData* p_src, DstData* p_dst)
+    template <typename DstSliceOriginIdx>
+    __device__ void Run(const SrcDesc& src_desc,
+                        const SrcData* p_src,
+                        const DstDesc&,
+                        const DstSliceOriginIdx&,
+                        DstData* p_dst)
     {
         constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
@@ -600,7 +613,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
             make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                        generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

-        Run(src_desc, p_src, p_dst, src_iterator_hacks);
+        Run(src_desc, p_src, DstDesc{}, DstSliceOriginIdx{}, p_dst, src_iterator_hacks);
     }

     __device__ static constexpr auto GetSrcCoordinateResetStep()
@@ -685,12 +698,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
     SrcCoord src_slice_origin_coord_;
 }; // namespace ck

+// Assume:
+// 1. src_desc and dst_desc are not known at compile-time
+// 2. src_slice_origin and dst_slice_origin are not known at compile-time
+// 3. Use thread buffer
 // this version does the following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
 // and sometimes useless instructions
 // 1. It does not keep reference to tensor descriptor
 // 2. It does not construct new tensor coordinate for this->Run()
 // 3. It does not use pointer for VGPR thread buffer
 // 4. It calculates the offset for the thread buffer directly, instead of moving the coordinate
 template <typename SliceLengths,
           InMemoryDataOperation DstInMemOp,
           typename SrcData,
...
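The reworked Run overload above takes the compile-time-known destination descriptor and slice origin as unnamed tag arguments (DstDesc{}, DstSliceOriginIdx{}) instead of storing them in the transfer object, which is what the comments mean by not keeping a reference to the tensor descriptor and thereby avoiding scratch memory. A rough sketch of that idiom with hypothetical types, not the library's real classes:

#include <cstddef>

// Hypothetical compile-time descriptor: all layout information lives in the
// type, so a value of it is an empty tag that costs no registers.
template <std::size_t N>
struct StaticDesc
{
    static constexpr std::size_t GetElementSpaceSize() { return N; }
};

template <typename DstDesc>
struct Transfer
{
    // DstDesc and the slice-origin index are taken as empty tag arguments and
    // never stored, so nothing about them can spill to scratch memory.
    template <typename T, typename DstSliceOriginIdx>
    void Run(const T* p_src, DstDesc, DstSliceOriginIdx, T* p_dst) const
    {
        for(std::size_t i = 0; i < DstDesc::GetElementSpaceSize(); ++i)
            p_dst[i] = p_src[i];
    }
};

// usage: descriptor and origin are re-materialized at the call site, e.g.
//   Transfer<StaticDesc<8>>{}.Run(p_src, StaticDesc<8>{}, std::integral_constant<int, 0>{}, p_dst);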