"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "66f8bd6869935e25a8091b00636666d9bfac8d2f"
Commit bd0098af authored by Chao Liu's avatar Chao Liu
Browse files

use dedicated threadwise_copy for 1x1, perf at 80%

parent 5245a016
...@@ -192,7 +192,6 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -192,7 +192,6 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
#elif 0 #elif 0
// 1x1, 14x14, Vega 20, disable lds_double_buffer, enable register double buffer
// 1x1, 14x14, Pascal, enable lds_double_buffer, disable register double buffer // 1x1, 14x14, Pascal, enable lds_double_buffer, disable register double buffer
constexpr index_t BPerBlock = 64; constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128; constexpr index_t KPerBlock = 128;
...@@ -245,6 +244,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -245,6 +244,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr index_t InBlockCopyDataPerRead = 4; constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t OutThreadCopyDataPerWrite = 4;
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
#endif #endif
...@@ -295,7 +296,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -295,7 +296,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
WeiBlockCopyThreadPerDim0, WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1, WeiBlockCopyThreadPerDim1,
InBlockCopyDataPerRead, InBlockCopyDataPerRead,
WeiBlockCopyDataPerRead>{}; WeiBlockCopyDataPerRead,
OutThreadCopyDataPerWrite>{};
float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>, float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
dim3(GridSize), dim3(GridSize),
......
...@@ -206,11 +206,7 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn ...@@ -206,11 +206,7 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
__syncthreads()) __syncthreads())
{ {
// load data // load data
#if 1
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
#elif 0
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()]; Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
...@@ -219,9 +215,13 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn ...@@ -219,9 +215,13 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard); p_wei_register_clipboard);
#if 1
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block); blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block); blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block);
#else
vmcnt(0);
blockwise_in_copy.RunStoreRegisterClipboard_asm(p_in_register_clipboard, p_in_block);
blockwise_wei_copy.RunStoreRegisterClipboard_asm(p_wei_register_clipboard, p_wei_block);
#endif #endif
__syncthreads(); __syncthreads();
...@@ -232,16 +232,16 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn ...@@ -232,16 +232,16 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
{ {
for(index_t x = 0; x < X; ++x) for(index_t x = 0; x < X; ++x)
{ {
#if 1 #if 0
blockwise_gemm.Run blockwise_gemm.Run
#elif 0 #elif 0
blockwise_gemm.Run_RegisterDoubleBuffer blockwise_gemm.Run_RegisterDoubleBuffer
#elif 0 #elif 1
blockwise_gemm.Run_asm blockwise_gemm.Run_asm
#endif #endif
(p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block + y * Wi + x, p_in_block + y * Wi + x,
p_out_thread); p_out_thread);
} }
} }
} }
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "blockwise_4d_tensor_op.hip.hpp" #include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp" #include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_2d_tensor_op.hip.hpp" #include "threadwise_2d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "blockwise_gemm.hip.hpp" #include "blockwise_gemm.hip.hpp"
// define B = flatten(N, Hi, Wi) // define B = flatten(N, Hi, Wi)
...@@ -31,7 +32,8 @@ template <index_t GridSize, ...@@ -31,7 +32,8 @@ template <index_t GridSize,
index_t WeiBlockCopyThreadPerDim0, index_t WeiBlockCopyThreadPerDim0,
index_t WeiBlockCopyThreadPerDim1, index_t WeiBlockCopyThreadPerDim1,
index_t InBlockCopyDataPerRead, index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead> index_t WeiBlockCopyDataPerRead,
index_t OutThreadCopyDataPerWrite>
struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
{ {
__device__ void Run(const Float* const __restrict__ p_in_global, __device__ void Run(const Float* const __restrict__ p_in_global,
...@@ -369,25 +371,55 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -369,25 +371,55 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) #if 1
if(Y == 1 && X == 1)
{ // pure 1x1 conv
constexpr index_t K2_ = GemmMPerThreadSubC;
constexpr index_t K1_ = KPerBlock / KPerThread;
constexpr index_t B2_ = GemmNPerThreadSubC;
constexpr index_t B1_ = BPerBlock / BPerThread;
constexpr auto out_6d_global_desc = make_ConstantTensorDescriptor(
Sequence<K / (K1_ * K2_), K1_, K2_, B / (B1_ * B2_), B1_, B2_>{});
constexpr auto out_6d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerBlock / (K1_ * K2_), 1, K2_, BPerBlock / (B1_ * B2_), 1, B2_>{});
constexpr auto out_kb_global_desc = make_ConstantTensorDescriptor(Sequence<K, B>{});
threadwise_6d_tensor_copy(
out_6d_thread_desc,
p_out_thread,
out_6d_global_desc,
p_out_global +
out_kb_global_desc.Get1dIndex(k_thread_data_begin, b_thread_data_begin),
out_6d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite>{});
}
else
#endif
{ {
for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
{ {
const auto c_thread_mtx_distance = for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); {
const auto c_thread_mtx_distance =
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);
index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row; index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row;
index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col; index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col;
index_t h_data = b_data / (Wi * N); index_t h_data = b_data / (Wi * N);
index_t itmp = b_data - h_data * (Wi * N); index_t itmp = b_data - h_data * (Wi * N);
index_t w_data = itmp / N; index_t w_data = itmp / N;
index_t n_data = itmp - w_data * N; index_t n_data = itmp - w_data * N;
if(n_data < N && h_data < Ho && w_data < Wo) if(n_data < N && h_data < Ho && w_data < Wo)
{ {
p_out_global[out_khwn_global_desc.Get1dIndex(k_data, h_data, w_data, n_data)] = p_out_global[out_khwn_global_desc.Get1dIndex(
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; k_data, h_data, w_data, n_data)] =
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)];
}
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment