Commit 4facbe99 authored by Chao Liu
Browse files

experiment inline asm: lds read

parent 766b0a9e
...@@ -40,7 +40,7 @@ struct GeneratorTensor_Checkboard ...@@ -40,7 +40,7 @@ struct GeneratorTensor_Checkboard
template <class... Ts> template <class... Ts>
double operator()(Ts... Xs) const double operator()(Ts... Xs) const
{ {
std::array<index_t, sizeof...(Ts)> dims = {{Xs...}}; std::array<index_t, sizeof...(Ts)> dims = {{static_cast<index_t>(Xs)...}};
return std::accumulate(dims.begin(), return std::accumulate(dims.begin(),
dims.end(), dims.end(),
true, true,
...@@ -593,9 +593,9 @@ int main(int argc, char* argv[]) ...@@ -593,9 +593,9 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 1 #elif 1
// 1x1 filter, 14x14 image, C = 512 // 1x1 filter, 14x14 image, C = 256
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 512; constexpr index_t C = 256;
constexpr index_t HI = 14; constexpr index_t HI = 14;
constexpr index_t WI = 14; constexpr index_t WI = 14;
constexpr index_t K = 512; constexpr index_t K = 512;
...@@ -638,10 +638,10 @@ int main(int argc, char* argv[]) ...@@ -638,10 +638,10 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
#if 0 #if 1
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 1 #elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 0 #elif 0
......
...@@ -563,7 +563,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -563,7 +563,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
__device__ void Run_RegisterDoubleBuffer(FloatA* const p_a_block, __device__ void Run_RegisterDoubleBuffer(FloatA* const p_a_block,
FloatB* const p_b_block, FloatB* const p_b_block,
FloatC* p_c_thread, FloatC* p_c_thread,
Accumulator f_accum) const Accumulator f_accum,
float* p_lds_begin) const
{ {
constexpr auto True = integral_constant<bool, true>{}; constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{}; constexpr auto False = integral_constant<bool, false>{};
...@@ -610,21 +611,23 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -610,21 +611,23 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
#pragma unroll #pragma unroll
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A { // copy A-sub to form A
threadwise_matrix_copy(a_block_mtx, threadwise_matrix_copy_v2(a_block_mtx,
p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster, p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
a_thread_sub_mtx, a_thread_sub_mtx,
p_a_thread_0 + m_repeat * MPerThreadSubC, p_a_thread_0 + m_repeat * MPerThreadSubC,
a_thread_sub_mtx.GetLengths()); a_thread_sub_mtx.GetLengths(),
p_lds_begin);
} }
#pragma unroll #pragma unroll
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ // copy B-sub to form B { // copy B-sub to form B
threadwise_matrix_copy(b_block_mtx, threadwise_matrix_copy_v2(b_block_mtx,
p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster, p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
b_thread_sub_mtx, b_thread_sub_mtx,
p_b_thread_0 + n_repeat * NPerThreadSubC, p_b_thread_0 + n_repeat * NPerThreadSubC,
b_thread_sub_mtx.GetLengths()); b_thread_sub_mtx.GetLengths(),
p_lds_begin);
} }
bool even_loop = true; bool even_loop = true;
...@@ -643,27 +646,35 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -643,27 +646,35 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
#pragma unroll #pragma unroll
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A { // copy A-sub to form A
threadwise_matrix_copy(a_block_mtx, threadwise_matrix_copy_v2(a_block_mtx,
p_a_block + mMyThreadOffsetA + p_a_block + mMyThreadOffsetA +
(k_begin + 1) * a_block_mtx.RowStride() + (k_begin + 1) * a_block_mtx.RowStride() +
m_repeat * MPerLevel1Cluster, m_repeat * MPerLevel1Cluster,
a_thread_sub_mtx, a_thread_sub_mtx,
p_a_thread_next + m_repeat * MPerThreadSubC, p_a_thread_next + m_repeat * MPerThreadSubC,
a_thread_sub_mtx.GetLengths()); a_thread_sub_mtx.GetLengths(),
p_lds_begin);
} }
#pragma unroll #pragma unroll
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ // copy B-sub to form B { // copy B-sub to form B
threadwise_matrix_copy(b_block_mtx, threadwise_matrix_copy_v2(b_block_mtx,
p_b_block + mMyThreadOffsetB + p_b_block + mMyThreadOffsetB +
(k_begin + 1) * b_block_mtx.RowStride() + (k_begin + 1) * b_block_mtx.RowStride() +
n_repeat * NPerLevel1Cluster, n_repeat * NPerLevel1Cluster,
b_thread_sub_mtx, b_thread_sub_mtx,
p_b_thread_next + n_repeat * NPerThreadSubC, p_b_thread_next + n_repeat * NPerThreadSubC,
b_thread_sub_mtx.GetLengths()); b_thread_sub_mtx.GetLengths(),
p_lds_begin);
} }
#if 1
asm volatile("\n \
s_waitcnt lgkmcnt(0) \n \
" ::);
#endif
// C = A * B // C = A * B
threadwise_gemm(a_thread_mtx, threadwise_gemm(a_thread_mtx,
True, True,
......
...@@ -206,6 +206,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric ...@@ -206,6 +206,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
__shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)]; __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
__shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)]; __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
#if 1
constexpr Float* p_lds_begin = p_wei_block;
#endif
const Float* p_in_global_block_offset = const Float* p_in_global_block_offset =
p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin); p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin);
...@@ -246,7 +250,8 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric ...@@ -246,7 +250,8 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
(p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block + y * Wi + x, p_in_block + y * Wi + x,
p_out_thread, p_out_thread,
f_accum); f_accum,
p_lds_begin);
} }
} }
} }
......
...@@ -289,10 +289,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( ...@@ -289,10 +289,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
#else #else
blockwise_gemm.Run_RegisterDoubleBuffer blockwise_gemm.Run_RegisterDoubleBuffer
#endif #endif
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block_now + y * Wi + x, p_in_block_now + y * Wi + x,
p_out_thread, p_out_thread,
f_accum); f_accum);
} }
} }
...@@ -319,10 +319,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( ...@@ -319,10 +319,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
#else #else
blockwise_gemm.Run_RegisterDoubleBuffer blockwise_gemm.Run_RegisterDoubleBuffer
#endif #endif
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block_now + y * Wi + x, p_in_block_now + y * Wi + x,
p_out_thread, p_out_thread,
f_accum); f_accum);
} }
} }
} }
......
...@@ -10,7 +10,6 @@ __device__ void threadwise_matrix_copy(SrcMatrix, ...@@ -10,7 +10,6 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
constexpr auto src_mtx = SrcMatrix{}; constexpr auto src_mtx = SrcMatrix{};
constexpr auto dst_mtx = DstMatrix{}; constexpr auto dst_mtx = DstMatrix{};
#if 0
for(index_t i = 0; i < NRow; ++i) for(index_t i = 0; i < NRow; ++i)
{ {
for(index_t j = 0; j < NCol; ++j) for(index_t j = 0; j < NCol; ++j)
...@@ -21,7 +20,39 @@ __device__ void threadwise_matrix_copy(SrcMatrix, ...@@ -21,7 +20,39 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
p_dst[dst_index] = p_src[src_index]; p_dst[dst_index] = p_src[src_index];
} }
} }
#elif 1 }
template <class Float, class SrcMatrix, class DstMatrix, index_t NRow, index_t NCol>
__device__ void threadwise_matrix_copy_v2(SrcMatrix,
const Float* __restrict__ p_src,
DstMatrix,
Float* __restrict__ p_dst,
Sequence<NRow, NCol>,
const float* p_lds_begin)
{
constexpr auto src_mtx = SrcMatrix{};
constexpr auto dst_mtx = DstMatrix{};
#if 1
for(index_t i = 0; i < NRow; ++i)
{
for(index_t j = 0; j < NCol; ++j)
{
const index_t src_index = src_mtx.Get1dIndex(i, j);
const index_t dst_index = dst_mtx.Get1dIndex(i, j);
#if 0
p_dst[dst_index] = p_src[src_index];
#else
asm volatile("\n \
ds_read_b32 %0, %1 \n \
"
: "=v"(p_dst[dst_index])
: "v"((uint32_t)((uintptr_t)((p_src + src_index) - p_lds_begin))));
#endif
}
}
#elif 0
static_assert(NCol == 4, "only for NCol == 4"); static_assert(NCol == 4, "only for NCol == 4");
using vector_t = typename vector_type<Float, 4>::MemoryType; using vector_t = typename vector_type<Float, 4>::MemoryType;
...@@ -38,8 +69,8 @@ __device__ void threadwise_matrix_copy(SrcMatrix, ...@@ -38,8 +69,8 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
asm volatile("\n \ asm volatile("\n \
ds_read_b128 %0, %1, offset:0 \n \ ds_read_b128 %0, %1, offset:0 \n \
" "
: "=v"(*(reinterpret_cast<vector_t*>(p_dst+dst_index))) : "=v"(*(reinterpret_cast<vector_t*>(p_dst + dst_index)))
: "v"((uint32_t)(p_src + src_index))); : "v"((uint32_t)((uintptr_t)(p_src + src_index - p_lds_begin))));
#endif #endif
} }
#endif #endif
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment