Commit 683e2596 authored by root

rename: CYX* tile parameters to E* (CYXPerThreadLoop -> EPerThreadLoop, CYXPerBlock -> EPerBlock, cyx_begin -> e_begin)

parent 8fb97941
@@ -18,7 +18,7 @@ template <index_t BlockSize,
           index_t KPerThread,
           index_t HPerThread,
           index_t WPerThread,
-          index_t CYXPerThreadLoop,
+          index_t EPerThreadLoop,
           index_t ThreadGemmADataPerRead_K,
           index_t ThreadGemmBDataPerRead_W>
 struct BlockwiseGemm_km_kn_m0m1n0n1_v3
@@ -130,14 +130,14 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         constexpr auto a_block_mtx = BlockMatrixA{};
-        constexpr auto CYXPerBlock = a_block_mtx.GetLength(I0);
+        constexpr auto EPerBlock = a_block_mtx.GetLength(I0);

         // thread A, B for GEMM
         constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<CYXPerThreadLoop>{}, Number<KPerThread>{}));
+            make_tuple(Number<EPerThreadLoop>{}, Number<KPerThread>{}));

         constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
-            Number<CYXPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
+            Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));

         constexpr auto c_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
             Number<KPerThread>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
@@ -146,7 +146,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         constexpr auto a_thread_copy = ThreadwiseSliceCopy_a<BlockMatrixA,
                                                              decltype(a_thread_mtx),
-                                                             CYXPerThreadLoop,
+                                                             EPerThreadLoop,
                                                              KPerThread,
                                                              ThreadGemmADataPerRead_K>{};
@@ -155,15 +155,15 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
                                           decltype(c_thread_mtx)>{};

         // loop over k
 #pragma unroll
-        for(index_t cyx_begin = 0; cyx_begin < CYXPerBlock; cyx_begin += CYXPerThreadLoop)
+        for(index_t e_begin = 0; e_begin < EPerBlock; e_begin += EPerThreadLoop)
         {
-            a_thread_copy.Run(p_a_block + a_block_mtx.CalculateOffset(make_tuple(cyx_begin, 0)) +
+            a_thread_copy.Run(p_a_block + a_block_mtx.CalculateOffset(make_tuple(e_begin, 0)) +
                                   mMyThreadOffsetA,
                               p_a_thread);

             threadwise_gemm.Run(p_a_thread,
                                 p_b_thread +
-                                    b_thread_mtx.CalculateOffset(make_tuple(cyx_begin, 0, 0, 0)),
+                                    b_thread_mtx.CalculateOffset(make_tuple(e_begin, 0, 0, 0)),
                                 p_c_thread);
         }
     }
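The loop touched by this rename walks the GEMM reduction dimension, now called E (the CYX -> E rename suggests it is the merged C*Y*X axis of the implicit-GEMM convolution), in steps of EPerThreadLoop: each iteration copies an A slice into registers and accumulates C[k][hw] += A[e][k] * B[e][hw]. Below is a minimal standalone sketch of that access pattern, assuming plain arrays and made-up sizes; the names EPerBlock, KPerThread, HWPerThread here are illustrative values, not the kernel's actual tuning parameters or API.

// Sketch only: plain-C++ model of the blockwise GEMM main loop after the rename.
#include <array>
#include <cstdio>

constexpr int EPerBlock      = 8; // reduction length held per block (illustrative)
constexpr int EPerThreadLoop = 2; // slice of E processed per inner iteration (illustrative)
constexpr int KPerThread     = 4; // per-thread output rows (illustrative)
constexpr int HWPerThread    = 4; // per-thread output columns, H*W flattened (illustrative)

int main()
{
    std::array<float, EPerBlock * KPerThread>   a{}; // A stored [E][K]
    std::array<float, EPerBlock * HWPerThread>  b{}; // B stored [E][HW]
    std::array<float, KPerThread * HWPerThread> c{}; // C stored [K][HW]

    a.fill(1.0f);
    b.fill(1.0f);

    // loop over the E (reduction) dimension, EPerThreadLoop elements at a time
    for(int e_begin = 0; e_begin < EPerBlock; e_begin += EPerThreadLoop)
    {
        // accumulate the C tile from this E slice: C[k][hw] += A[e][k] * B[e][hw]
        for(int e = e_begin; e < e_begin + EPerThreadLoop; ++e)
            for(int k = 0; k < KPerThread; ++k)
                for(int hw = 0; hw < HWPerThread; ++hw)
                    c[k * HWPerThread + hw] += a[e * KPerThread + k] * b[e * HWPerThread + hw];
    }

    std::printf("c[0][0] = %f\n", c[0]); // expect 8 with all-ones inputs
    return 0;
}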