"git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "5ce60ecc173102452f2546581716b3163708dcc9"
Commit 22114959 authored by Chao Liu

adding inline asm for 16x4 gemm

parent 68ae0731
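For readers unfamiliar with the constraint syntax used below, here is a minimal, self-contained sketch of the v_mac_f32 idiom this commit introduces: one A value is broadcast against four consecutive B values and accumulated into four C registers (one row of the 16x4 thread tile per call). Each v_mac_f32 computes dst += src0 * src1, and the matching "0"-"3" input constraints tie each C output register to its previous value, so the accumulation is read-modify-write. The helper name and the standalone form are illustrative only and not part of this repository; the sketch assumes a gfx9-class target compiled with hipcc, where v_mac_f32 is available.

__device__ inline void mac_row_1x4(float a, const float* b, float* c)
{
    // c[0..3] += a * b[0..3], written as raw GCN VALU instructions
    asm volatile("\n \
    v_mac_f32 %0, %4, %5 \n \
    v_mac_f32 %1, %4, %6 \n \
    v_mac_f32 %2, %4, %7 \n \
    v_mac_f32 %3, %4, %8 \n \
    "
                 : "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3])
                 : "v"(a), "v"(b[0]), "v"(b[1]), "v"(b[2]), "v"(b[3]),
                   "0"(c[0]), "1"(c[1]), "2"(c[2]), "3"(c[3]));
}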
@@ -526,7 +526,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
}
}
// this version put copy and compute in same place, experimenting with compiler behaviour
template <class FloatA, class FloatB, class FloatC, class Accumulator>
__device__ void Run_v2(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
@@ -687,6 +686,231 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
}
}
template <class FloatA, class FloatB, class FloatC, class Accumulator>
__device__ void Run_v3(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread,
Accumulator f_accum) const
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr unsigned MPerThread = c_thread_mtx.NRow();
constexpr unsigned NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
// A is transposed, B is not
constexpr auto a_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
// thread A-sub, B-sub for copy
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
// loop over k
//#pragma unroll
for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
// read first batch of A, B
// copy A-sub to form A
//#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
for(unsigned i = 0; i < a_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < a_thread_mtx.NCol(); ++j)
{
p_a_thread[a_thread_mtx.Get1dIndex(i, m_repeat * MPerThreadSubC + j)] =
p_a_block[a_block_mtx.Get1dIndex(k_begin + i,
m_repeat * MPerLevel1Cluster + j) +
mMyThreadOffsetA];
}
}
}
// copy B-sub to form B
//#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
for(unsigned i = 0; i < b_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < b_thread_mtx.NCol(); ++j)
{
p_b_thread[b_thread_mtx.Get1dIndex(i, n_repeat * NPerThreadSubC + j)] =
p_b_block[b_block_mtx.Get1dIndex(k_begin + i,
n_repeat * NPerLevel1Cluster + j) +
mMyThreadOffsetB];
}
}
}
// loop over batch
//#pragma unroll
for(unsigned ib = 0; ib + 1 < BatchPerThread; ++ib)
{
// do current batch of gemm
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
{
#if 0
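// reference path (disabled): accumulate each C element through the f_accum functor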
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
{
const unsigned aindex =
a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex =
c_thread_mtx.Get1dIndex(i, j) + ib * ThreadMatrixStrideC;
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
}
}
#elif 1
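// hand-written GCN path: for each of the 16 rows of the C tile, broadcast one A
// element against 4 consecutive B elements with v_mac_f32 (dst += src0 * src1);
// the "0"-"3" constraints tie each C output to its previous value, making the
// accumulation read-modify-write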
static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
"asm is only for 16x4");
const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned cindex = c_thread_mtx.Get1dIndex(i, 0) + ib * ThreadMatrixStrideC;
asm volatile("\n \
v_mac_f32 %0, %4, %5 \n \
v_mac_f32 %1, %4, %6 \n \
v_mac_f32 %2, %4, %7 \n \
v_mac_f32 %3, %4, %8 \n \
"
: "=v"(p_c_thread[cindex + 0]),
"=v"(p_c_thread[cindex + 1]),
"=v"(p_c_thread[cindex + 2]),
"=v"(p_c_thread[cindex + 3])
: "v"(p_a_thread[aindex]),
"v"(p_b_thread[bindex + 0]),
"v"(p_b_thread[bindex + 1]),
"v"(p_b_thread[bindex + 2]),
"v"(p_b_thread[bindex + 3]),
"0"(p_c_thread[cindex + 0]),
"1"(p_c_thread[cindex + 1]),
"2"(p_c_thread[cindex + 2]),
"3"(p_c_thread[cindex + 3]));
}
#endif
}
// read next batch of a, b
if(BlockMatrixStrideA != 0)
{
//#pragma unroll
for(unsigned m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
for(unsigned i = 0; i < a_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < a_thread_mtx.NCol(); ++j)
{
p_a_thread[a_thread_mtx.Get1dIndex(i,
m_repeat * MPerThreadSubC + j)] =
p_a_block[a_block_mtx.Get1dIndex(
k_begin + i, m_repeat * MPerLevel1Cluster + j) +
(ib + 1) * BlockMatrixStrideA + mMyThreadOffsetA];
}
}
}
}
if(BlockMatrixStrideB != 0)
{
//#pragma unroll
for(unsigned n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
for(unsigned i = 0; i < b_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < b_thread_mtx.NCol(); ++j)
{
p_b_thread[b_thread_mtx.Get1dIndex(i,
n_repeat * NPerThreadSubC + j)] =
p_b_block[b_block_mtx.Get1dIndex(
k_begin + i, n_repeat * NPerLevel1Cluster + j) +
(ib + 1) * BlockMatrixStrideB + mMyThreadOffsetB];
}
}
}
}
}
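// the final batch is peeled out of the loop above so that no next-batch A/B copy
// is issued after it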
// do last batch of gemm
for(unsigned k = 0; k < a_thread_mtx.NRow(); ++k)
{
#if 0
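// reference path (disabled) for the last batch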
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
{
for(unsigned j = 0; j < c_thread_mtx.NCol(); ++j)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_thread_mtx.Get1dIndex(k, j);
const unsigned cindex = c_thread_mtx.Get1dIndex(i, j) +
(BatchPerThread - 1) * ThreadMatrixStrideC;
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
}
}
#elif 1
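// same v_mac_f32 sequence as above, with the C index offset to the last batch
// of the thread C matrix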
static_assert(c_thread_mtx.NRow() == 16 && c_thread_mtx.NCol() == 4,
"asm is only for 16x4");
const unsigned bindex = b_thread_mtx.Get1dIndex(k, 0);
for(unsigned i = 0; i < c_thread_mtx.NRow(); ++i)
{
const unsigned aindex = a_thread_mtx.Get1dIndex(k, i); // A is transposed
const unsigned cindex =
c_thread_mtx.Get1dIndex(i, 0) + (BatchPerThread - 1) * ThreadMatrixStrideC;
asm volatile("\n \
v_mac_f32 %0, %4, %5 \n \
v_mac_f32 %1, %4, %6 \n \
v_mac_f32 %2, %4, %7 \n \
v_mac_f32 %3, %4, %8 \n \
"
: "=v"(p_c_thread[cindex + 0]),
"=v"(p_c_thread[cindex + 1]),
"=v"(p_c_thread[cindex + 2]),
"=v"(p_c_thread[cindex + 3])
: "v"(p_a_thread[aindex]),
"v"(p_b_thread[bindex + 0]),
"v"(p_b_thread[bindex + 1]),
"v"(p_b_thread[bindex + 2]),
"v"(p_b_thread[bindex + 3]),
"0"(p_c_thread[cindex + 0]),
"1"(p_c_thread[cindex + 1]),
"2"(p_c_thread[cindex + 2]),
"3"(p_c_thread[cindex + 3]));
}
#endif
}
}
}
template <class BlockMatrixC, unsigned BlockMatrixStrideC, class FloatC>
__device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
FloatC* __restrict__ p_c_block) const
...
@@ -209,15 +209,17 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
{
for(unsigned x = 0; x < X; ++x)
{
#if 0
blockwise_batch_gemm.Run
#elif 0
blockwise_batch_gemm.Run_v2
#elif 1
blockwise_batch_gemm.Run_v3
#endif
(p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
p_out_thread,
[](auto& acc, const auto&& v) { acc += v; });
}
}
}
...