"git@developer.sourcefind.cn:yangql/composable_kernel-1.git" did not exist on "5696c81ffd9f5a7554ce7c47a9d5ed21284ae4f0"
Commit 43cd8529 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 04c5527d
...@@ -39,8 +39,8 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc, ...@@ -39,8 +39,8 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc)); Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc));
auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto s, auto r) { auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto y, auto x) {
wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r); wei_csrk(c, y, x, k) = wei_kcsr(k, c, y, x);
}; };
make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, Y, X)( make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, Y, X)(
......
...@@ -41,8 +41,8 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(InDesc, ...@@ -41,8 +41,8 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(InDesc,
Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc)); Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc));
auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto s, auto r) { auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto y, auto x) {
wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r); wei_csrk(c, y, x, k) = wei_kcsr(k, c, y, x);
}; };
make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, Y, X)( make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, Y, X)(
......
...@@ -55,7 +55,7 @@ void device_implicit_gemm_convolution_2_chwn_csrk_khwn(InDesc, ...@@ -55,7 +55,7 @@ void device_implicit_gemm_convolution_2_chwn_csrk_khwn(InDesc,
Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc)); Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc));
make_ParallelTensorFunctor( make_ParallelTensorFunctor(
[&](auto k, auto c, auto s, auto r) { wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r); }, [&](auto k, auto c, auto y, auto x) { wei_csrk(c, y, x, k) = wei_kcsr(k, c, y, x); },
K, K,
C, C,
Y, Y,
......
...@@ -204,12 +204,12 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric ...@@ -204,12 +204,12 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
__syncthreads(); __syncthreads();
// a series of batched GEMM // a series of batched GEMM
for(unsigned s = 0; s < Y; ++s) for(unsigned y = 0; y < Y; ++y)
{ {
for(unsigned r = 0; r < X; ++r) for(unsigned x = 0; x < X; ++x)
{ {
blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0), p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
p_out_thread, p_out_thread,
[](auto& acc, const auto&& v) { acc += v; }); [](auto& acc, const auto&& v) { acc += v; });
} }
......
...@@ -245,14 +245,14 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded( ...@@ -245,14 +245,14 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(
__syncthreads(); __syncthreads();
// a series of batched GEMM // a series of batched GEMM
for(unsigned s = 0; s < Y; ++s) for(unsigned y = 0; y < Y; ++y)
{ {
for(unsigned r = 0; r < X; ++r) for(unsigned x = 0; x < X; ++x)
{ {
auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0), p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
p_out_thread, p_out_thread,
f_accum); f_accum);
} }
......
...@@ -275,9 +275,9 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b ...@@ -275,9 +275,9 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b
// compute on current data // compute on current data
// a series of GEMM // a series of GEMM
for(unsigned s = 0; s < Y; ++s) for(unsigned y = 0; y < Y; ++y)
{ {
for(unsigned r = 0; r < X; ++r) for(unsigned x = 0; x < X; ++x)
{ {
auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 1 #if 1
...@@ -285,8 +285,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b ...@@ -285,8 +285,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b
#else #else
blockwise_gemm.Run_RegisterDoubleBuffer blockwise_gemm.Run_RegisterDoubleBuffer
#endif #endif
(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), (p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block_now + s * Wi + r, p_in_block_now + y * Wi + x,
p_out_thread, p_out_thread,
f_accum); f_accum);
} }
...@@ -305,9 +305,9 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b ...@@ -305,9 +305,9 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b
__syncthreads(); __syncthreads();
for(unsigned s = 0; s < Y; ++s) for(unsigned y = 0; y < Y; ++y)
{ {
for(unsigned r = 0; r < X; ++r) for(unsigned x = 0; x < X; ++x)
{ {
auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 0 #if 0
...@@ -315,8 +315,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b ...@@ -315,8 +315,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b
#else #else
blockwise_gemm.Run_RegisterDoubleBuffer blockwise_gemm.Run_RegisterDoubleBuffer
#endif #endif
(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, s, r, 0), (p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block_now + s * Wi + r, p_in_block_now + y * Wi + x,
p_out_thread, p_out_thread,
f_accum); f_accum);
} }
......
...@@ -8,16 +8,16 @@ ...@@ -8,16 +8,16 @@
#include <iostream> #include <iostream>
template <class Range> template <class Range>
std::ostream& LogRange(std::ostream& os, Range&& r, std::string delim) std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
{ {
bool first = true; bool first = true;
for(auto&& x : r) for(auto&& v : range)
{ {
if(first) if(first)
first = false; first = false;
else else
os << delim; os << delim;
os << x; os << v;
} }
return os; return os;
} }
......
...@@ -38,16 +38,16 @@ __device__ void threadwise_direct_convolution_1(InDesc, ...@@ -38,16 +38,16 @@ __device__ void threadwise_direct_convolution_1(InDesc,
{ {
for(unsigned c = 0; c < wei_desc.GetLength(I1); ++c) for(unsigned c = 0; c < wei_desc.GetLength(I1); ++c)
{ {
for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s) for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y)
{ {
for(unsigned r = 0; r < wei_desc.GetLength(I3); ++r) for(unsigned x = 0; x < wei_desc.GetLength(I3); ++x)
{ {
const unsigned hi = ho + s; const unsigned hi = ho + y;
const unsigned wi = wo + r; const unsigned wi = wo + x;
const unsigned in_index = in_desc.Get1dIndex(n, c, hi, wi); const unsigned in_index = in_desc.Get1dIndex(n, c, hi, wi);
const unsigned wei_index = wei_desc.Get1dIndex(k, c, s, r); const unsigned wei_index = wei_desc.Get1dIndex(k, c, y, x);
const unsigned out_index = out_desc.Get1dIndex(n, k, ho, wo); const unsigned out_index = out_desc.Get1dIndex(n, k, ho, wo);
...@@ -153,18 +153,18 @@ __device__ void threadwise_direct_convolution_3(InDesc, ...@@ -153,18 +153,18 @@ __device__ void threadwise_direct_convolution_3(InDesc,
#if 0 #if 0
// this verison reused old input data in register, and read new data from LDS // this verison reused old input data in register, and read new data from LDS
// loop over vertical direction // loop over vertical direction
for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s) for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y)
{ {
// read first input // read first input
threadwise_4d_tensor_copy(in_desc, threadwise_4d_tensor_copy(in_desc,
p_in + in_desc.Get1dIndex(0, 0, s, 0), p_in + in_desc.Get1dIndex(0, 0, y, 0),
in_reg_desc, in_reg_desc,
p_in_reg, p_in_reg,
in_reg_desc.GetLengths()); in_reg_desc.GetLengths());
// read first 1x1 weight // read first 1x1 weight
threadwise_4d_tensor_copy(wei_desc, threadwise_4d_tensor_copy(wei_desc,
p_wei + wei_desc.Get1dIndex(0, 0, s, 0), p_wei + wei_desc.Get1dIndex(0, 0, y, 0),
wei_reg_desc, wei_reg_desc,
p_wei_reg, p_wei_reg,
wei_reg_desc.GetLengths()); wei_reg_desc.GetLengths());
...@@ -174,11 +174,11 @@ __device__ void threadwise_direct_convolution_3(InDesc, ...@@ -174,11 +174,11 @@ __device__ void threadwise_direct_convolution_3(InDesc,
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out); in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
// loop over horizontal direction // loop over horizontal direction
for(unsigned r = 1; r < wei_desc.GetLength(I3); ++r) for(unsigned x = 1; x < wei_desc.GetLength(I3); ++x)
{ {
// read new weight // read new weight
threadwise_4d_tensor_copy(wei_desc, threadwise_4d_tensor_copy(wei_desc,
p_wei + wei_desc.Get1dIndex(0, 0, s, r), p_wei + wei_desc.Get1dIndex(0, 0, y, x),
wei_reg_desc, wei_reg_desc,
p_wei_reg, p_wei_reg,
wei_reg_desc.GetLengths()); wei_reg_desc.GetLengths());
...@@ -189,7 +189,7 @@ __device__ void threadwise_direct_convolution_3(InDesc, ...@@ -189,7 +189,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
// read new input // read new input
threadwise_4d_tensor_copy( threadwise_4d_tensor_copy(
in_desc, in_desc,
p_in + in_desc.Get1dIndex(0, 0, s, r + in_reg_desc.GetLength(I3) - 1), p_in + in_desc.Get1dIndex(0, 0, y, x + in_reg_desc.GetLength(I3) - 1),
in_reg_desc, in_reg_desc,
p_in_reg + p_in_reg +
in_reg_desc.Get1dIndex(0, 0, 0, in_reg_desc.GetLength(I3) - in_w_new_read), in_reg_desc.Get1dIndex(0, 0, 0, in_reg_desc.GetLength(I3) - in_w_new_read),
...@@ -203,21 +203,21 @@ __device__ void threadwise_direct_convolution_3(InDesc, ...@@ -203,21 +203,21 @@ __device__ void threadwise_direct_convolution_3(InDesc,
#elif 1 #elif 1
// this version read all input from LDS when filter moves // this version read all input from LDS when filter moves
// loop over vertical direction // loop over vertical direction
for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s) for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y)
{ {
// loop over horizontal direction // loop over horizontal direction
for(unsigned r = 0; r < wei_desc.GetLength(I3); ++r) for(unsigned x = 0; x < wei_desc.GetLength(I3); ++x)
{ {
// read new weight // read new weight
threadwise_4d_tensor_copy(wei_desc, threadwise_4d_tensor_copy(wei_desc,
p_wei + wei_desc.Get1dIndex(0, 0, s, r), p_wei + wei_desc.Get1dIndex(0, 0, y, x),
wei_reg_desc, wei_reg_desc,
p_wei_reg, p_wei_reg,
wei_reg_desc.GetLengths()); wei_reg_desc.GetLengths());
// read new input // read new input
threadwise_4d_tensor_copy(in_desc, threadwise_4d_tensor_copy(in_desc,
p_in + in_desc.Get1dIndex(0, 0, s, r), p_in + in_desc.Get1dIndex(0, 0, y, x),
in_reg_desc, in_reg_desc,
p_in_reg, p_in_reg,
in_reg_desc.GetLengths()); in_reg_desc.GetLengths());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment