Commit 753b98b5 authored by Jing Zhang's avatar Jing Zhang
Browse files

refactor inline asm

parent 114fdb58
This diff is collapsed.
#pragma once #pragma once
#include "inline_asm.hpp"
template <class Float, class SrcMatrix, class DstMatrix, index_t NRow, index_t NCol> template <class Float, class SrcMatrix, class DstMatrix, index_t NRow, index_t NCol>
__device__ void threadwise_matrix_copy(SrcMatrix, __device__ void threadwise_matrix_copy(SrcMatrix,
const Float* __restrict__ p_src, const Float* __restrict__ p_src,
...@@ -21,18 +23,18 @@ __device__ void threadwise_matrix_copy(SrcMatrix, ...@@ -21,18 +23,18 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
p_dst[dst_index] = p_src[src_index]; p_dst[dst_index] = p_src[src_index];
} }
} }
#elif 1 #else
static_assert(NCol == 4, "only for NCol == 4"); static_assert(NCol == 4, "only for NCol == 4");
using vector_t = typename vector_type<Float, 4>::MemoryType;
for(index_t i = 0; i < NRow; ++i) for(index_t i = 0; i < NRow; ++i)
{ {
const index_t src_index = src_mtx.Get1dIndex(i, 0); const index_t src_index = src_mtx.Get1dIndex(i, 0);
const index_t dst_index = dst_mtx.Get1dIndex(i, 0); const index_t dst_index = dst_mtx.Get1dIndex(i, 0);
*(reinterpret_cast<vector_t*>(&p_dst[dst_index])) = Float4 *reg_p = (Float4 *)&p_dst[dst_index];
*(reinterpret_cast<const vector_t*>(&p_src[src_index])); Float4 *loc_p = (Float4 *)&p_src[src_index];
ds_read_b128(reg_p[0], (void *)&loc_p[0]);
} }
#endif #endif
} }
...@@ -70,25 +72,20 @@ __device__ void threadwise_gemm(MatrixA, ...@@ -70,25 +72,20 @@ __device__ void threadwise_gemm(MatrixA,
for(index_t k = 0; k < K; ++k) for(index_t k = 0; k < K; ++k)
{ {
for(index_t i = 0; i < M; ++i) for(index_t i = 0; i < M; i+=4)
{
for(index_t j = 0; j < N; ++j)
{ {
const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed
const Float4 *a_vec = (const Float4 *)&p_a_thread[aindex];
for(index_t j = 0; j < N; j+=4)
{
const index_t bindex = b_mtx.Get1dIndex(k, j); const index_t bindex = b_mtx.Get1dIndex(k, j);
const index_t cindex = c_mtx.Get1dIndex(i, j); const index_t cindex = c_mtx.Get1dIndex(i, j);
#if 0 const Float4 *b_vec = (const Float4 *)&p_b_thread[bindex];
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]); Float4 *c_vec = (Float4 *)&p_c_thread[cindex];
#elif 1
asm volatile("\n \ outerProduct4x4(a_vec[0], b_vec[0], c_vec[0], c_vec[2], c_vec[4], c_vec[6]);
v_mac_f32 %0, %1, %2 \n \
"
: "=v"(p_c_thread[cindex])
: "v"(p_a_thread[aindex]),
"v"(p_b_thread[bindex]),
"0"(p_c_thread[cindex]));
#endif
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment