Commit 2058bec8 authored by Jing Zhang's avatar Jing Zhang
Browse files

fused functions

parent 766b0a9e
...@@ -379,8 +379,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -379,8 +379,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
// loop over k // loop over k
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{ {
#pragma unroll
// copy A-sub to form A // copy A-sub to form A
#if 0
#pragma unroll
// MRepeat = 2
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ {
threadwise_matrix_copy( threadwise_matrix_copy(
...@@ -391,9 +393,22 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -391,9 +393,22 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths()); a_thread_sub_mtx.GetLengths());
} }
#else
{
auto src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
auto dst_index = a_thread_sub_mtx.Get1dIndex(0, 0);
#pragma unroll const float4* loc = (const float4 *)(p_a_block + src_index);
float4* reg = (float4 *)(p_a_thread + dst_index);
reg[0] = loc[0];
reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4];
}
#endif
#if 0
// copy B-sub to form B // copy B-sub to form B
#pragma unroll
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ {
threadwise_matrix_copy( threadwise_matrix_copy(
...@@ -404,8 +419,21 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -404,8 +419,21 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC), p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths()); b_thread_sub_mtx.GetLengths());
} }
#else
{
auto src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB;
auto dst_index = b_thread_sub_mtx.Get1dIndex(0, 0);
const float4* loc = (const float4 *)(p_b_block + src_index);
float4* reg = (float4 *)(p_b_thread + dst_index);
reg[0] = loc[0];
reg[NPerThreadSubC/4] = loc[NPerLevel1Cluster/4];
}
#endif
// C = A * B // C = A * B
#if 0
threadwise_gemm(a_thread_mtx, threadwise_gemm(a_thread_mtx,
True, True,
p_a_thread, p_a_thread,
...@@ -416,6 +444,24 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -416,6 +444,24 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
False, False,
p_c_thread, p_c_thread,
f_accum); f_accum);
#else
for(index_t k = 0; k < 1; ++k)
{
// M = 8
for(index_t i = 0; i < 8; ++i)
{
// N = 8
for(index_t j = 0; j < 8; ++j)
{
const index_t aindex = a_thread_sub_mtx.Get1dIndex(k, i); // A is transposed
const index_t bindex = b_thread_sub_mtx.Get1dIndex(k, j);
const index_t cindex = c_thread_mtx.Get1dIndex(i, j);
p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex];
}
}
}
#endif
} }
} }
......
...@@ -236,7 +236,7 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric ...@@ -236,7 +236,7 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
for(index_t x = 0; x < X; ++x) for(index_t x = 0; x < X; ++x)
{ {
auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 0 #if 1
blockwise_gemm.Run blockwise_gemm.Run
#elif 0 #elif 0
blockwise_gemm.Run_asm blockwise_gemm.Run_asm
......
...@@ -10,9 +10,11 @@ __device__ void threadwise_matrix_copy(SrcMatrix, ...@@ -10,9 +10,11 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
constexpr auto src_mtx = SrcMatrix{}; constexpr auto src_mtx = SrcMatrix{};
constexpr auto dst_mtx = DstMatrix{}; constexpr auto dst_mtx = DstMatrix{};
#if 0 #if 1
//NRow = 1
for(index_t i = 0; i < NRow; ++i) for(index_t i = 0; i < NRow; ++i)
{ {
//NCol = 4
for(index_t j = 0; j < NCol; ++j) for(index_t j = 0; j < NCol; ++j)
{ {
const index_t src_index = src_mtx.Get1dIndex(i, j); const index_t src_index = src_mtx.Get1dIndex(i, j);
...@@ -76,10 +78,13 @@ __device__ void threadwise_gemm(MatrixA, ...@@ -76,10 +78,13 @@ __device__ void threadwise_gemm(MatrixA,
constexpr index_t N = c_mtx.NCol(); constexpr index_t N = c_mtx.NCol();
constexpr index_t K = a_mtx.NRow(); // A is transposed constexpr index_t K = a_mtx.NRow(); // A is transposed
// K = 1
for(index_t k = 0; k < K; ++k) for(index_t k = 0; k < K; ++k)
{ {
// M = 8
for(index_t i = 0; i < M; ++i) for(index_t i = 0; i < M; ++i)
{ {
// N = 8
for(index_t j = 0; j < N; ++j) for(index_t j = 0; j < N; ++j)
{ {
const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment