"tests/vscode:/vscode.git/clone" did not exist on "717d15719c713fd3ee9ab0d8eb3d98116758036e"
Commit e38c1b73 authored by Chao Liu's avatar Chao Liu
Browse files

replacing array with vector for tensor data

parent 841b1480
......@@ -130,13 +130,13 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
// thread A, B for GEMM
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
Number<KPerThreadLoop>{}, Number<MPerThread>{});
make_tuple(Number<KPerThreadLoop>{}, Number<MPerThread>{}));
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
Number<KPerThreadLoop>{}, Number<NPerThread>{});
make_tuple(Number<KPerThreadLoop>{}, Number<NPerThread>{}));
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
FloatB p_b_thread[b_thread_mtx.GetElementSpaceSize()];
constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixA,
decltype(a_thread_mtx),
......@@ -153,37 +153,31 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_mtx),
decltype(b_thread_mtx),
decltype(c_thread_mtx)>{};
#pragma unroll
// loop over k
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
#pragma unroll
static_for<0, K, KPerThreadLoop>{}([&](auto k_begin) {
// read A
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
static_for<0, MRepeat, 1>{}([&](auto m_repeat) {
a_thread_copy.Run(p_a_block +
a_block_mtx.CalculateOffset(
make_tuple(k_begin, m_repeat * MPerLevel1Cluster)) +
mMyThreadOffsetA,
p_a_thread + a_thread_mtx.CalculateOffset(
make_tuple(0, m_repeat * MPerThreadSubC)));
}
});
#pragma unroll
// read B
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
static_for<0, NRepeat, 1>{}([&](auto n_repeat) {
b_thread_copy.Run(p_b_block +
b_block_mtx.CalculateOffset(
make_tuple(k_begin, n_repeat * NPerLevel1Cluster)) +
mMyThreadOffsetB,
p_b_thread + b_thread_mtx.CalculateOffset(
make_tuple(0, n_repeat * NPerThreadSubC)));
}
});
// C += A * B
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
}
});
}
template <typename FloatA, typename FloatB, typename FloatC>
......
......@@ -27,10 +27,14 @@ struct lambda_scalar_step_in_vector
}
};
// Assume:
// 1. src_desc is known at compile-time
// 2. dst_desc is not known at compile-time
// 3. src_slice_origin_idx is known at compile-time and it's 0
// 4. dst_slice_origin_idx is not known at compile-time
// this version is less likely to have scratch memory issue, due to:
// 1. It does not keep reference to tensor descriptor
// 2. It does not construct new tensor coordinate for this->Run()
// Assume src_slice_origin_idx is 0
// TODO: support non-zero src_slice_origin_idx
template <typename SrcData,
typename DstData,
......@@ -359,10 +363,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
DstCoord dst_slice_origin_coord_;
}; // namespace ck
// Assume:
// 1. src_desc is not known at compile-time
// 2. dst_desc is known at compile-time
// 3. src_slice_origin_idx is not known at compile-time
// 4. dst_slice_origin_idx is known at compile-time and it's 0
// this version is less likely to have scratch memory issue, due to:
// 1. It does not keep reference to tensor descriptor
// 2. It does not construct new tensor coordinate for this->Run()
// Assume dst_slice_origin_idx is 0
template <typename SrcData,
typename DstData,
typename SrcDesc,
......@@ -590,7 +598,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
}
}
__device__ void Run(const SrcDesc& src_desc, const SrcData* p_src, DstData* p_dst)
template <typename DstSliceOriginIdx>
__device__ void Run(const SrcDesc& src_desc,
const SrcData* p_src,
const DstDesc&,
const DstSliceOriginIdx&,
DstData* p_dst)
{
constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
......@@ -600,7 +613,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
Run(src_desc, p_src, p_dst, src_iterator_hacks);
Run(src_desc, p_src, DstDesc{}, DstSliceOriginIdx{}, p_dst, src_iterator_hacks);
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
......@@ -685,6 +698,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
SrcCoord src_slice_origin_coord_;
}; // namespace ck
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. src_slice_origin and dst_slice_origin are not known at compile-time
// 3. Use thread buffer
// this version does the following things to avoid "alloca" in LLVM-IR, which would cause scratch
// memory usage and sometimes useless instructions
// 1. It does not keep reference to tensor descriptor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment