Commit 766b0a9e authored by Chao Liu

experimenting

parent f35c64eb
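The diff below replaces the raw `unsigned` used for kernel template parameters and loop counters with the project's `index_t` alias. The alias itself is not part of this commit; as orientation only, a minimal sketch of what such a project-wide alias typically looks like (the concrete underlying width here is an assumption, not taken from this change):

    #include <cstdint>

    // Hypothetical sketch: one shared alias for tensor/loop indices, defined once
    // in a common header, instead of hard-coding "unsigned" in every kernel.
    using index_t = std::int32_t; // assumed width; the real typedef may differ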
 #pragma once
 #include "constant_integral.hip.hpp"
-template <unsigned NLoop>
+template <index_t NLoop>
 struct static_loop_n
 {
     template <class F>
@@ -24,7 +24,7 @@ struct static_loop_n<1>
     }
 };
-template <unsigned NLoop>
+template <index_t NLoop>
 struct static_const_reduce_n
 {
     template <class F, class Reduce>
...
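The bodies of static_loop_n and static_const_reduce_n are collapsed in this diff. For orientation only, a hypothetical sketch of the usual recursive pattern such compile-time loop helpers follow; the member names and exact semantics below are illustrative, not the actual contents of this file (it assumes the project's `Number<>` integral-constant wrapper and the `index_t` alias):

    // Illustrative sketch of a compile-time unrolled loop; not the real static_loop_n.
    template <index_t NLoop>
    struct static_loop_n_sketch
    {
        template <class F>
        __device__ void operator()(F f) const
        {
            f(Number<NLoop - 1>{});               // visit iteration NLoop-1
            static_loop_n_sketch<NLoop - 1>{}(f); // recurse over the remaining iterations
        }
    };

    template <>
    struct static_loop_n_sketch<1>
    {
        template <class F>
        __device__ void operator()(F f) const
        {
            f(Number<0>{}); // base case: a single iteration
        }
    };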
@@ -8,18 +8,18 @@ template <class Float,
           class InGlobalDesc,
           class WeiGlobalDesc,
           class OutGlobalDesc,
-          unsigned NPerBlock,
-          unsigned KPerBlock,
-          unsigned CPerBlock,
-          unsigned HoPerBlock,
-          unsigned WoPerBlock,
-          unsigned NPerThread,
-          unsigned KPerThread,
-          unsigned CPerThread,
-          unsigned HoPerThread,
-          unsigned WoPerThread,
-          unsigned BlockSize,
-          unsigned GridSize>
+          index_t NPerBlock,
+          index_t KPerBlock,
+          index_t CPerBlock,
+          index_t HoPerBlock,
+          index_t WoPerBlock,
+          index_t NPerThread,
+          index_t KPerThread,
+          index_t CPerThread,
+          index_t HoPerThread,
+          index_t WoPerThread,
+          index_t BlockSize,
+          index_t GridSize>
 __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_in_global,
                                               const Float* const __restrict__ p_wei_global,
                                               Float* const __restrict__ p_out_global)
@@ -33,16 +33,16 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
     constexpr auto wei_global_desc = WeiGlobalDesc{};
     constexpr auto out_global_desc = OutGlobalDesc{};
-    constexpr unsigned Y = wei_global_desc.GetLength(I2);
-    constexpr unsigned X = wei_global_desc.GetLength(I3);
-    constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
-    constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
-    constexpr unsigned NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
-    constexpr unsigned KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
-    constexpr unsigned HBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
-    constexpr unsigned WBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
+    constexpr index_t Y = wei_global_desc.GetLength(I2);
+    constexpr index_t X = wei_global_desc.GetLength(I3);
+    constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
+    constexpr index_t WiPerBlock = WoPerBlock + X - 1;
+    constexpr index_t NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
+    constexpr index_t KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
+    constexpr index_t HBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
+    constexpr index_t WBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
     constexpr auto in_block_global_desc = make_ConstantTensorDescriptor(
         Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, in_global_desc.GetStrides());
@@ -59,31 +59,31 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
     constexpr auto out_block_desc =
         make_ConstantTensorDescriptor(out_block_global_desc.GetLengths());
-    constexpr unsigned in_block_size = in_block_desc.GetElementSpace();
-    constexpr unsigned wei_block_size = wei_block_desc.GetElementSpace();
-    constexpr unsigned out_block_size = out_block_desc.GetElementSpace();
+    constexpr index_t in_block_size = in_block_desc.GetElementSpace();
+    constexpr index_t wei_block_size = wei_block_desc.GetElementSpace();
+    constexpr index_t out_block_size = out_block_desc.GetElementSpace();
     __shared__ Float p_in_block[in_block_size];
     __shared__ Float p_wei_block[wei_block_size];
     __shared__ Float p_out_block[out_block_size];
-    const unsigned block_id = blockIdx.x;
-    unsigned itmp = block_id;
-    unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
+    const index_t block_id = blockIdx.x;
+    index_t itmp = block_id;
+    index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
     itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
-    unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
+    index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
     itmp -= k_block_work_id * (HBlockWork * WBlockWork);
-    unsigned h_block_work_id = itmp / WBlockWork;
-    unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;
-    unsigned n_block_work_begin = n_block_work_id * NPerBlock;
-    unsigned k_block_work_begin = k_block_work_id * KPerBlock;
-    unsigned ho_block_work_begin = h_block_work_id * HoPerBlock;
-    unsigned wo_block_work_begin = w_block_work_id * WoPerBlock;
-    unsigned hi_block_work_begin = ho_block_work_begin; // minus padding
-    unsigned wi_block_work_begin = wo_block_work_begin; // minus padding
+    index_t h_block_work_id = itmp / WBlockWork;
+    index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;
+    index_t n_block_work_begin = n_block_work_id * NPerBlock;
+    index_t k_block_work_begin = k_block_work_id * KPerBlock;
+    index_t ho_block_work_begin = h_block_work_id * HoPerBlock;
+    index_t wo_block_work_begin = w_block_work_id * WoPerBlock;
+    index_t hi_block_work_begin = ho_block_work_begin; // minus padding
+    index_t wi_block_work_begin = wo_block_work_begin; // minus padding
     constexpr auto blockwise_in_copy =
         Blockwise4dTensorCopy1<BlockSize,
@@ -109,7 +109,7 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
     // set output tensor in LDS to 0
     blockwise_4d_tensor_set_zero<BlockSize>(out_block_desc, p_out_block);
-    for(unsigned c_block_work_begin = 0; c_block_work_begin < in_global_desc.GetLength(I1);
+    for(index_t c_block_work_begin = 0; c_block_work_begin < in_global_desc.GetLength(I1);
         c_block_work_begin += CPerBlock)
     {
         // copy input tensor to LDS
...
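These kernels linearize the (n, k, ho, wo) block coordinates into blockIdx.x and peel them back out with the divide/subtract sequence seen above. A standalone, host-side sketch of that arithmetic with made-up work counts (the values are examples only, not taken from any kernel configuration):

    #include <cstdio>

    using index_t = int; // stand-in for the project's index_t alias

    int main()
    {
        constexpr index_t KBlockWork = 4, HBlockWork = 3, WBlockWork = 5;

        index_t block_id = 37; // pretend blockIdx.x
        index_t itmp     = block_id;

        const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
        itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
        const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
        itmp -= k_block_work_id * (HBlockWork * WBlockWork);
        const index_t h_block_work_id = itmp / WBlockWork;
        const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;

        // 37 = 0*60 + 2*15 + 1*5 + 2  ->  (n, k, h, w) = (0, 2, 1, 2)
        std::printf("%d %d %d %d\n", n_block_work_id, k_block_work_id, h_block_work_id, w_block_work_id);
        return 0;
    }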
@@ -11,20 +11,20 @@ template <class Float,
           class InGlobalDesc,
           class WeiGlobalDesc,
           class OutGlobalDesc,
-          unsigned NPerBlock,
-          unsigned KPerBlock,
-          unsigned CPerBlock,
-          unsigned HoPerBlock,
-          unsigned WoPerBlock,
-          unsigned NPerThread,
-          unsigned KPerThread,
-          unsigned CPerThread,
-          unsigned HoPerThread,
-          unsigned WoPerThread,
-          unsigned InBlockCopyDataPerRead,
-          unsigned WeiBlockCopyDataPerRead,
-          unsigned BlockSize,
-          unsigned GridSize>
+          index_t NPerBlock,
+          index_t KPerBlock,
+          index_t CPerBlock,
+          index_t HoPerBlock,
+          index_t WoPerBlock,
+          index_t NPerThread,
+          index_t KPerThread,
+          index_t CPerThread,
+          index_t HoPerThread,
+          index_t WoPerThread,
+          index_t InBlockCopyDataPerRead,
+          index_t WeiBlockCopyDataPerRead,
+          index_t BlockSize,
+          index_t GridSize>
 __global__ void
 gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_in_global,
                                              const Float* const __restrict__ p_wei_global,
@@ -39,17 +39,17 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
     constexpr auto wei_kcyx_global_desc = WeiGlobalDesc{};
     constexpr auto out_nkhw_global_desc = OutGlobalDesc{};
-    constexpr unsigned N = in_nchw_global_desc.GetLength(I0);
-    constexpr unsigned K = wei_kcyx_global_desc.GetLength(I0);
-    constexpr unsigned C = wei_kcyx_global_desc.GetLength(I1);
-    constexpr unsigned Y = wei_kcyx_global_desc.GetLength(I2);
-    constexpr unsigned X = wei_kcyx_global_desc.GetLength(I3);
+    constexpr index_t N = in_nchw_global_desc.GetLength(I0);
+    constexpr index_t K = wei_kcyx_global_desc.GetLength(I0);
+    constexpr index_t C = wei_kcyx_global_desc.GetLength(I1);
+    constexpr index_t Y = wei_kcyx_global_desc.GetLength(I2);
+    constexpr index_t X = wei_kcyx_global_desc.GetLength(I3);
     constexpr auto wei_ke_global_desc = make_ConstantTensorDescriptor(
         Sequence<K, C * Y * X>{}); // 2d view of wei for blockwise copy
-    constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
-    constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
+    constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
+    constexpr index_t WiPerBlock = WoPerBlock + X - 1;
     constexpr auto in_nchw_block_desc = make_ConstantTensorDescriptor_aligned(
         Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, Number<InBlockCopyDataPerRead>{});
@@ -63,21 +63,21 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
         Sequence<wei_ke_block_desc.GetStride(I0), Y * X, X, 1>{});
     // shared mem
-    constexpr unsigned in_block_size =
+    constexpr index_t in_block_size =
         in_nchw_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
-    constexpr unsigned wei_block_size =
+    constexpr index_t wei_block_size =
         wei_kcyx_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
-    constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
+    constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
                                        ? InBlockCopyDataPerRead
                                        : WeiBlockCopyDataPerRead;
     __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
     __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
     // threadwise tensors
-    constexpr unsigned HiPerThread = HoPerThread + Y - 1;
-    constexpr unsigned WiPerThread = WoPerThread + X - 1;
+    constexpr index_t HiPerThread = HoPerThread + Y - 1;
+    constexpr index_t WiPerThread = WoPerThread + X - 1;
     constexpr auto in_nchw_thread_block_desc =
         make_ConstantTensorDescriptor(Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{},
@@ -93,56 +93,54 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
     Float p_out_thread[out_nkhw_thread_desc.GetElementSpace()];
     // divide block work
-    constexpr unsigned NBlockWork =
-        (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
-    constexpr unsigned KBlockWork =
-        (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
-    constexpr unsigned HBlockWork =
+    constexpr index_t NBlockWork = (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
+    constexpr index_t KBlockWork = (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
+    constexpr index_t HBlockWork =
         (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
-    constexpr unsigned WBlockWork =
+    constexpr index_t WBlockWork =
         (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
-    const unsigned block_id = blockIdx.x;
-    unsigned itmp = block_id;
-    const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
+    const index_t block_id = blockIdx.x;
+    index_t itmp = block_id;
+    const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
     itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
-    const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
+    const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
     itmp -= k_block_work_id * (HBlockWork * WBlockWork);
-    const unsigned h_block_work_id = itmp / WBlockWork;
-    const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;
-    const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
-    const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
-    const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
-    const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
-    const unsigned hi_block_data_begin = ho_block_data_begin; // minus padding
-    const unsigned wi_block_data_begin = wo_block_data_begin; // minus padding
+    const index_t h_block_work_id = itmp / WBlockWork;
+    const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;
+    const index_t n_block_data_begin = n_block_work_id * NPerBlock;
+    const index_t k_block_data_begin = k_block_work_id * KPerBlock;
+    const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
+    const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
+    const index_t hi_block_data_begin = ho_block_data_begin; // minus padding
+    const index_t wi_block_data_begin = wo_block_data_begin; // minus padding
     // divide thread work
-    constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
-    constexpr unsigned KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
-    constexpr unsigned HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
-    constexpr unsigned WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
-    const unsigned thread_id = threadIdx.x;
+    constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
+    constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
+    constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
+    constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
+    const index_t thread_id = threadIdx.x;
     itmp = thread_id;
-    const unsigned n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
+    const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
     itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
-    const unsigned k_thread_work_id = itmp / (HThreadWork * WThreadWork);
+    const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork);
     itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
-    const unsigned h_thread_work_id = itmp / WThreadWork;
-    const unsigned w_thread_work_id = itmp - h_thread_work_id * WThreadWork;
-    const unsigned n_thread_data_begin = n_thread_work_id * NPerThread;
-    const unsigned k_thread_data_begin = k_thread_work_id * KPerThread;
-    const unsigned ho_thread_data_begin = h_thread_work_id * HoPerThread;
-    const unsigned wo_thread_data_begin = w_thread_work_id * WoPerThread;
-    const unsigned hi_thread_data_begin = ho_thread_data_begin;
-    const unsigned wi_thread_data_begin = wo_thread_data_begin;
+    const index_t h_thread_work_id = itmp / WThreadWork;
+    const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork;
+    const index_t n_thread_data_begin = n_thread_work_id * NPerThread;
+    const index_t k_thread_data_begin = k_thread_work_id * KPerThread;
+    const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread;
+    const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread;
+    const index_t hi_thread_data_begin = ho_thread_data_begin;
+    const index_t wi_thread_data_begin = wo_thread_data_begin;
     constexpr auto blockwise_in_copy =
         Blockwise4dTensorCopy1<BlockSize,
@@ -172,7 +170,7 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
     // set threadwise output tensor to 0
     threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);
-    for(unsigned c_block_data_begin = 0; c_block_data_begin < C;
+    for(index_t c_block_data_begin = 0; c_block_data_begin < C;
         c_block_data_begin += CPerBlock, __syncthreads())
     {
         // copy input tensor to LDS
@@ -191,7 +189,7 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
         __syncthreads();
-        for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
+        for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
         {
             // threadwise convolution
 #if 1
...
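The LDS buffers in this kernel are padded up to a multiple of max_align so that both the input and weight staging areas keep the larger of the two vector-read alignments. A tiny worked sketch of that round-up expression (the sizes below are made-up example values):

    // Round an LDS element count up to a multiple of max_align, as done for
    // p_in_block / p_wei_block above. Values are example numbers only.
    constexpr int InBlockCopyDataPerRead  = 4;
    constexpr int WeiBlockCopyDataPerRead = 2;

    constexpr int max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
                                  ? InBlockCopyDataPerRead
                                  : WeiBlockCopyDataPerRead; // = 4

    constexpr int in_block_size  = 1742; // example element count
    constexpr int padded_in_size = max_align * ((in_block_size + max_align - 1) / max_align);
    static_assert(padded_in_size == 1744, "rounded up to the next multiple of 4");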
@@ -13,21 +13,21 @@ template <class TInWei,
           class InGlobalDesc,
           class WeiGlobalDesc,
           class OutGlobalDesc,
-          unsigned ScalarPerVector,
-          unsigned NPerBlock,
-          unsigned KPerBlock,
-          unsigned CPerBlock,
-          unsigned HoPerBlock,
-          unsigned WoPerBlock,
-          unsigned NPerThread,
-          unsigned KPerThread,
-          unsigned CPerThread,
-          unsigned HoPerThread,
-          unsigned WoPerThread,
-          unsigned InBlockCopyDataPerRead,
-          unsigned WeiBlockCopyDataPerRead,
-          unsigned BlockSize,
-          unsigned GridSize>
+          index_t ScalarPerVector,
+          index_t NPerBlock,
+          index_t KPerBlock,
+          index_t CPerBlock,
+          index_t HoPerBlock,
+          index_t WoPerBlock,
+          index_t NPerThread,
+          index_t KPerThread,
+          index_t CPerThread,
+          index_t HoPerThread,
+          index_t WoPerThread,
+          index_t InBlockCopyDataPerRead,
+          index_t WeiBlockCopyDataPerRead,
+          index_t BlockSize,
+          index_t GridSize>
 __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
     const typename vector_type<TInWei,
                                ScalarPerVector>::MemoryType* const __restrict__ p_in_vec_global,
@@ -49,17 +49,17 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
     constexpr auto wei_kcyx_vec_global_desc = WeiGlobalDesc{};
     constexpr auto out_nkhw_global_desc = OutGlobalDesc{};
-    constexpr unsigned N = in_nchw_vec_global_desc.GetLength(I0);
-    constexpr unsigned K = wei_kcyx_vec_global_desc.GetLength(I0);
-    constexpr unsigned C = wei_kcyx_vec_global_desc.GetLength(I1);
-    constexpr unsigned Y = wei_kcyx_vec_global_desc.GetLength(I2);
-    constexpr unsigned X = wei_kcyx_vec_global_desc.GetLength(I3);
+    constexpr index_t N = in_nchw_vec_global_desc.GetLength(I0);
+    constexpr index_t K = wei_kcyx_vec_global_desc.GetLength(I0);
+    constexpr index_t C = wei_kcyx_vec_global_desc.GetLength(I1);
+    constexpr index_t Y = wei_kcyx_vec_global_desc.GetLength(I2);
+    constexpr index_t X = wei_kcyx_vec_global_desc.GetLength(I3);
     constexpr auto wei_ke_vec_global_desc = make_ConstantTensorDescriptor(
         Sequence<K, C * Y * X>{}); // 2d view of wei for blockwise copy
-    constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
-    constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
+    constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
+    constexpr index_t WiPerBlock = WoPerBlock + X - 1;
     constexpr auto in_nchw_vec_block_desc = make_ConstantTensorDescriptor_aligned(
         Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, Number<InBlockCopyDataPerRead>{});
@@ -73,15 +73,15 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
         Sequence<wei_ke_vec_block_desc.GetStride(I0), Y * X, X, 1>{});
     // shared mem
-    constexpr unsigned in_block_size =
+    constexpr index_t in_block_size =
         in_nchw_vec_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
-    constexpr unsigned wei_block_size =
+    constexpr index_t wei_block_size =
         wei_kcyx_vec_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
-    constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
+    constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
                                        ? InBlockCopyDataPerRead
                                        : WeiBlockCopyDataPerRead;
     __shared__ in_vector_mem_t
         p_in_vec_block[max_align * ((in_block_size + max_align - 1) / max_align)];
@@ -89,8 +89,8 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
         p_wei_vec_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
     // threadwise tensors
-    constexpr unsigned HiPerThread = HoPerThread + Y - 1;
-    constexpr unsigned WiPerThread = WoPerThread + X - 1;
+    constexpr index_t HiPerThread = HoPerThread + Y - 1;
+    constexpr index_t WiPerThread = WoPerThread + X - 1;
     constexpr auto in_nchw_vec_thread_block_desc =
         make_ConstantTensorDescriptor(Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{},
@@ -106,56 +106,54 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
     out_scalar_t p_out_thread[out_nkhw_thread_desc.GetElementSpace()];
     // divide block work
-    constexpr unsigned NBlockWork =
-        (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
-    constexpr unsigned KBlockWork =
-        (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
-    constexpr unsigned HBlockWork =
+    constexpr index_t NBlockWork = (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
+    constexpr index_t KBlockWork = (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
+    constexpr index_t HBlockWork =
        (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
-    constexpr unsigned WBlockWork =
+    constexpr index_t WBlockWork =
        (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
-    const unsigned block_id = blockIdx.x;
-    unsigned itmp = block_id;
-    const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
+    const index_t block_id = blockIdx.x;
+    index_t itmp = block_id;
+    const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
     itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
-    const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
+    const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
     itmp -= k_block_work_id * (HBlockWork * WBlockWork);
-    const unsigned h_block_work_id = itmp / WBlockWork;
-    const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;
-    const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
-    const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
-    const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
-    const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
-    const unsigned hi_block_data_begin = ho_block_data_begin; // minus padding
-    const unsigned wi_block_data_begin = wo_block_data_begin; // minus padding
+    const index_t h_block_work_id = itmp / WBlockWork;
+    const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;
+    const index_t n_block_data_begin = n_block_work_id * NPerBlock;
+    const index_t k_block_data_begin = k_block_work_id * KPerBlock;
+    const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
+    const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
+    const index_t hi_block_data_begin = ho_block_data_begin; // minus padding
+    const index_t wi_block_data_begin = wo_block_data_begin; // minus padding
     // divide thread work
-    constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
-    constexpr unsigned KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
-    constexpr unsigned HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
-    constexpr unsigned WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
-    const unsigned thread_id = threadIdx.x;
+    constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
+    constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
+    constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
+    constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
+    const index_t thread_id = threadIdx.x;
     itmp = thread_id;
-    const unsigned n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
+    const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
     itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
-    const unsigned k_thread_work_id = itmp / (HThreadWork * WThreadWork);
+    const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork);
     itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
-    const unsigned h_thread_work_id = itmp / WThreadWork;
-    const unsigned w_thread_work_id = itmp - h_thread_work_id * WThreadWork;
-    const unsigned n_thread_data_begin = n_thread_work_id * NPerThread;
-    const unsigned k_thread_data_begin = k_thread_work_id * KPerThread;
-    const unsigned ho_thread_data_begin = h_thread_work_id * HoPerThread;
-    const unsigned wo_thread_data_begin = w_thread_work_id * WoPerThread;
-    const unsigned hi_thread_data_begin = ho_thread_data_begin;
-    const unsigned wi_thread_data_begin = wo_thread_data_begin;
+    const index_t h_thread_work_id = itmp / WThreadWork;
+    const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork;
+    const index_t n_thread_data_begin = n_thread_work_id * NPerThread;
+    const index_t k_thread_data_begin = k_thread_work_id * KPerThread;
+    const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread;
+    const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread;
+    const index_t hi_thread_data_begin = ho_thread_data_begin;
+    const index_t wi_thread_data_begin = wo_thread_data_begin;
     constexpr auto blockwise_in_copy =
         Blockwise4dTensorCopy1<BlockSize,
@@ -188,7 +186,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
     threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);
 #endif
-    for(unsigned c_block_data_begin = 0; c_block_data_begin < C;
+    for(index_t c_block_data_begin = 0; c_block_data_begin < C;
         c_block_data_begin += CPerBlock, __syncthreads())
     {
         // copy input tensor to LDS
@@ -207,7 +205,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
         __syncthreads();
-        for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
+        for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
         {
             // threadwise convolution
 #if 1
...
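All of these kernels size the per-block input tile from the output tile as HiPerBlock = HoPerBlock + Y - 1 (and likewise for width): with stride 1 and no dilation, producing HoPerBlock output rows needs Y - 1 extra rows of input halo. A one-line numeric check with made-up tile and filter sizes (stride-1, no-dilation is an assumption consistent with the formula, not stated elsewhere in this commit):

    // Halo arithmetic used by the kernels above. Example numbers only:
    // a 3x3 filter producing an 8x8 output tile per block.
    constexpr int Y = 3, X = 3;
    constexpr int HoPerBlock = 8, WoPerBlock = 8;
    constexpr int HiPerBlock = HoPerBlock + Y - 1; // 10 input rows needed
    constexpr int WiPerBlock = WoPerBlock + X - 1; // 10 input columns needed
    static_assert(HiPerBlock == 10 && WiPerBlock == 10, "output tile plus filter halo");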
@@ -8,32 +8,32 @@
 #include "threadwise_4d_tensor_op.hip.hpp"
 #include "blockwise_batched_gemm.hip.hpp"
-template <unsigned GridSize,
-          unsigned BlockSize,
+template <index_t GridSize,
+          index_t BlockSize,
           class Float,
           class InGlobalDesc,
           class WeiGlobalDesc,
           class OutGlobalDesc,
-          unsigned NPerBlock,
-          unsigned KPerBlock,
-          unsigned CPerBlock,
-          unsigned HoPerBlock,
-          unsigned WoPerBlock,
-          unsigned NPerThread,
-          unsigned KPerThread,
-          unsigned HoPerThread,
-          unsigned WoPerThread,
+          index_t NPerBlock,
+          index_t KPerBlock,
+          index_t CPerBlock,
+          index_t HoPerBlock,
+          index_t WoPerBlock,
+          index_t NPerThread,
+          index_t KPerThread,
+          index_t HoPerThread,
+          index_t WoPerThread,
           class InBlockCopyThreadPerDims,
-          unsigned InBlockCopyDataPerRead,
-          unsigned WeiBlockCopyDataPerRead,
-          unsigned GemmMPerThreadSubC,
-          unsigned GemmNPerThreadSubC,
-          unsigned GemmMLevel0Cluster,
-          unsigned GemmNLevel0Cluster,
-          unsigned GemmMLevel1Cluster,
-          unsigned GemmNLevel1Cluster,
-          unsigned GemmKPerThreadLoop,
-          unsigned OutThreadCopyDataPerWrite>
+          index_t InBlockCopyDataPerRead,
+          index_t WeiBlockCopyDataPerRead,
+          index_t GemmMPerThreadSubC,
+          index_t GemmNPerThreadSubC,
+          index_t GemmMLevel0Cluster,
+          index_t GemmNLevel0Cluster,
+          index_t GemmMLevel1Cluster,
+          index_t GemmNLevel1Cluster,
+          index_t GemmKPerThreadLoop,
+          index_t OutThreadCopyDataPerWrite>
 __global__ void
 gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restrict__ p_in_global,
                                                     const Float* const __restrict__ p_wei_global,
@@ -55,39 +55,39 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
     constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
     constexpr auto out_khwn_global_desc = OutGlobalDesc{};
-    constexpr unsigned C = in_chwn_global_desc.GetLength(I0);
-    constexpr unsigned K = out_khwn_global_desc.GetLength(I0);
-    constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1);
-    constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2);
-    constexpr unsigned N = out_khwn_global_desc.GetLength(I3);
-    constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1);
-    constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2);
-    constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
-    constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
+    constexpr index_t C = in_chwn_global_desc.GetLength(I0);
+    constexpr index_t K = out_khwn_global_desc.GetLength(I0);
+    constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
+    constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
+    constexpr index_t N = out_khwn_global_desc.GetLength(I3);
+    constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
+    constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
+    constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
+    constexpr index_t WiPerBlock = WoPerBlock + X - 1;
     // divide block work: [K, Ho, Wo, N]
-    constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
-    constexpr unsigned HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
-    constexpr unsigned WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
-    constexpr unsigned NBlockWork = (N + NPerBlock - 1) / NPerBlock;
-    const unsigned k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
-    unsigned itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
-    const unsigned h_block_work_id = itmp / (WBlockWork * NBlockWork);
+    constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
+    constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
+    constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
+    constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
+    const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
+    index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
+    const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
     itmp -= h_block_work_id * (WBlockWork * NBlockWork);
-    const unsigned w_block_work_id = itmp / NBlockWork;
-    const unsigned n_block_work_id = itmp - w_block_work_id * NBlockWork;
-    const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
-    const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
-    const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
-    const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
-    const unsigned hi_block_data_begin = ho_block_data_begin;
-    const unsigned wi_block_data_begin = wo_block_data_begin;
+    const index_t w_block_work_id = itmp / NBlockWork;
+    const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
+    const index_t k_block_data_begin = k_block_work_id * KPerBlock;
+    const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
+    const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
+    const index_t n_block_data_begin = n_block_work_id * NPerBlock;
+    const index_t hi_block_data_begin = ho_block_data_begin;
+    const index_t wi_block_data_begin = wo_block_data_begin;
     // flattend (2d) tensor view of gridwise weight
     constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
@@ -164,15 +164,15 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
         HoPerThread>{};
     // LDS: be careful of alignment
-    constexpr unsigned in_block_size =
+    constexpr index_t in_block_size =
         in_chwn_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
-    constexpr unsigned wei_block_size =
+    constexpr index_t wei_block_size =
         wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
-    constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
+    constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
                                        ? InBlockCopyDataPerRead
                                        : WeiBlockCopyDataPerRead;
     __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
     __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
@@ -191,10 +191,10 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
     const Float* p_wei_global_block_begin =
         p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
-    for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
+    for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
         p_in_global_block_begin += CPerBlock * in_chwn_global_desc.GetStride(I0),
         p_wei_global_block_begin += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
         __syncthreads())
     {
         // input: global mem to LDS
         blockwise_in_copy.Run(p_in_global_block_begin, p_in_block);
@@ -205,9 +205,9 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
         __syncthreads();
         // a series of batched GEMM
-        for(unsigned y = 0; y < Y; ++y)
+        for(index_t y = 0; y < Y; ++y)
         {
-            for(unsigned x = 0; x < X; ++x)
+            for(index_t x = 0; x < X; ++x)
             {
 #if 0
                 blockwise_batch_gemm.Run
@@ -227,26 +227,26 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
     const auto c_thread_mtx_begin =
         blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
-    for(unsigned k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
+    for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
     {
-        for(unsigned ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
+        for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
         {
-            for(unsigned wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
+            for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
             {
-                for(unsigned n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
+                for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
                 {
-                    const unsigned b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
+                    const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
                     const auto c_thread_mtx_distance =
                         blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);
-                    const unsigned ho_thread =
+                    const index_t ho_thread =
                         c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
-                    const unsigned k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
-                    const unsigned b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
-                    const unsigned wo_thread = b_thread / NPerBlock;
-                    const unsigned n_thread = b_thread % NPerBlock;
+                    const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
+                    const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
+                    const index_t wo_thread = b_thread / NPerBlock;
+                    const index_t n_thread = b_thread % NPerBlock;
                     p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
                                                                  ho_block_data_begin + ho_thread,
@@ -261,19 +261,19 @@ gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn(const Float* const __restric
     const auto c_thread_mtx_begin =
         blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
-    const unsigned k_thread_data_begin = c_thread_mtx_begin.row;
-    const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
-    const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
-    const unsigned n_thread_data_begin = c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
+    const index_t k_thread_data_begin = c_thread_mtx_begin.row;
+    const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
+    const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
+    const index_t n_thread_data_begin = c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
     // this is for v2 GEMM
     // output is a 8d tensor
     if(NPerThread < NPerBlock && WoPerThread == 1)
     {
-        constexpr unsigned N1_ = GemmNPerThreadSubC;
-        constexpr unsigned W1_ = WoPerBlock / ((WoPerThread * NPerThread) / GemmNPerThreadSubC);
-        constexpr unsigned K2_ = GemmMPerThreadSubC;
-        constexpr unsigned K1_ = KPerBlock / KPerThread;
+        constexpr index_t N1_ = GemmNPerThreadSubC;
+        constexpr index_t W1_ = WoPerBlock / ((WoPerThread * NPerThread) / GemmNPerThreadSubC);
+        constexpr index_t K2_ = GemmMPerThreadSubC;
+        constexpr index_t K1_ = KPerBlock / KPerThread;
         constexpr auto out_8d_global_desc = make_ConstantTensorDescriptor(
             Sequence<K / (K1_ * K2_), K1_, K2_, Ho, Wo / W1_, W1_, N / N1_, N1_>{});
...
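In the implicit-GEMM kernel above, the GEMM's column dimension packs Wo and N together (b = wo * NPerBlock + n), so a thread's column index is split back out with a divide/modulo, as in wo_thread = b_thread / NPerBlock and n_thread = b_thread % NPerBlock. A small numeric sketch of that split (example values only):

    // Splitting the packed GEMM column index b back into (wo, n), as done above.
    // NPerBlock and b_thread are made-up example values.
    constexpr int NPerBlock = 16;
    constexpr int b_thread  = 37;                   // packed column index
    constexpr int wo_thread = b_thread / NPerBlock; // 2
    constexpr int n_thread  = b_thread % NPerBlock; // 5
    static_assert(wo_thread * NPerBlock + n_thread == b_thread, "round-trips exactly");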
@@ -7,26 +7,26 @@
 #include "threadwise_4d_tensor_op.hip.hpp"
 #include "blockwise_gemm.hip.hpp"
-template <unsigned GridSize,
-          unsigned BlockSize,
+template <index_t GridSize,
+          index_t BlockSize,
           class Float,
           class InGlobalDesc,
           class WeiGlobalDesc,
           class OutGlobalDesc,
           class LowerPads,
           class UpperPads,
-          unsigned NPerBlock,
-          unsigned KPerBlock,
-          unsigned CPerBlock,
-          unsigned HoPerBlock,
-          unsigned WoPerBlock,
-          unsigned NPerThread,
-          unsigned KPerThread,
-          unsigned CPerThread,
-          unsigned HoPerThread,
-          unsigned WoPerThread,
-          unsigned WeiBlockCopyThreadPerDim0,
-          unsigned WeiBlockCopyThreadPerDim1>
+          index_t NPerBlock,
+          index_t KPerBlock,
+          index_t CPerBlock,
+          index_t HoPerBlock,
+          index_t WoPerBlock,
+          index_t NPerThread,
+          index_t KPerThread,
+          index_t CPerThread,
+          index_t HoPerThread,
+          index_t WoPerThread,
+          index_t WeiBlockCopyThreadPerDim0,
+          index_t WeiBlockCopyThreadPerDim1>
 __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
     const Float* const __restrict__ p_in_global,
     const Float* const __restrict__ p_wei_global,
@@ -48,42 +48,42 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
     constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
     constexpr auto out_khwn_global_desc = OutGlobalDesc{};
-    constexpr unsigned C = in_chwn_global_desc.GetLength(I0);
-    constexpr unsigned K = out_khwn_global_desc.GetLength(I0);
-    constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1);
-    constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2);
-    constexpr unsigned N = out_khwn_global_desc.GetLength(I3);
-    constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1);
-    constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2);
-    constexpr unsigned HPadLow = LowerPads{}.Get(I0);
-    constexpr unsigned WPadLow = LowerPads{}.Get(I1);
-    constexpr unsigned HPadUp = UpperPads{}.Get(I0);
-    constexpr unsigned WPadUp = UpperPads{}.Get(I1);
-    constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
-    constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
+    constexpr index_t C = in_chwn_global_desc.GetLength(I0);
+    constexpr index_t K = out_khwn_global_desc.GetLength(I0);
+    constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
+    constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
+    constexpr index_t N = out_khwn_global_desc.GetLength(I3);
+    constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
+    constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
+    constexpr index_t HPadLow = LowerPads{}.Get(I0);
+    constexpr index_t WPadLow = LowerPads{}.Get(I1);
+    constexpr index_t HPadUp = UpperPads{}.Get(I0);
+    constexpr index_t WPadUp = UpperPads{}.Get(I1);
+    constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
+    constexpr index_t WiPerBlock = WoPerBlock + X - 1;
     // divide block work: [K, Ho, Wo, N]
-    constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
-    constexpr unsigned HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
-    constexpr unsigned WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
-    constexpr unsigned NBlockWork = (N + NPerBlock - 1) / NPerBlock;
-    const unsigned k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
-    unsigned itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
-    const unsigned h_block_work_id = itmp / (WBlockWork * NBlockWork);
+    constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
+    constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
+    constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
+    constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
+    const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
+    index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
+    const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
     itmp -= h_block_work_id * (WBlockWork * NBlockWork);
-    const unsigned w_block_work_id = itmp / NBlockWork;
-    const unsigned n_block_work_id = itmp - w_block_work_id * NBlockWork;
-    const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
-    const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
-    const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
-    const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
+    const index_t w_block_work_id = itmp / NBlockWork;
+    const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
+    const index_t k_block_data_begin = k_block_work_id * KPerBlock;
+    const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
+    const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
+    const index_t n_block_data_begin = n_block_work_id * NPerBlock;
     // flattened (2d) tensor view of wei in global mem
     constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
@@ -114,11 +114,11 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
     // blockwise copy
     // input: format is [C, Hi, Wi, N]
-    const unsigned h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0;
-    const unsigned w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0;
-    const unsigned h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0;
-    const unsigned w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0;
+    const index_t h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0;
+    const index_t w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0;
+    const index_t h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0;
+    const index_t w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0;
 #if 0
     if(get_thread_local_1d_id() == 0)
@@ -204,8 +204,8 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
         true>{};
     // LDS
-    constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
-    constexpr unsigned wei_block_size = wei_cyxk_block_desc.GetElementSpace();
+    constexpr index_t in_block_size = in_chwn_block_desc.GetElementSpace();
+    constexpr index_t wei_block_size = wei_cyxk_block_desc.GetElementSpace();
     __shared__ Float p_in_block[in_block_size];
     __shared__ Float p_wei_block[wei_block_size];
@@ -219,9 +219,9 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
     const Float* p_wei_global_block_begin =
         p_wei_global + wei_ek_global_desc.Get1dIndex(0, k_block_data_begin);
-    for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
+    for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
        p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0),
        __syncthreads())
     {
 #if 1
         // input: global mem to LDS,
@@ -245,9 +245,9 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
         __syncthreads();
         // a series of batched GEMM
-        for(unsigned y = 0; y < Y; ++y)
+        for(index_t y = 0; y < Y; ++y)
         {
-            for(unsigned x = 0; x < X; ++x)
+            for(index_t x = 0; x < X; ++x)
             {
                 auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
@@ -262,10 +262,10 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
     const auto matrix_c_index =
         blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
-    const unsigned ho_thread_data_begin = matrix_c_index.batch;
-    const unsigned k_thread_data_begin = matrix_c_index.row;
-    const unsigned wo_thread_data_begin = matrix_c_index.col / NPerBlock;
-    const unsigned n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
+    const index_t ho_thread_data_begin = matrix_c_index.batch;
+    const index_t k_thread_data_begin = matrix_c_index.row;
+    const index_t wo_thread_data_begin = matrix_c_index.col / NPerBlock;
+    const index_t n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
 #if 0
     printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
...
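The padded variant above applies the H/W pads only on blocks that actually touch a tensor border: the lower pad goes to work-id 0 and the upper pad to the last work-id along each dimension. A compact sketch of that selection, with example work counts and pad sizes (the helper functions are illustrative, not part of the kernel):

    // Boundary-only padding, mirroring h_block_pad_low / h_block_pad_up above.
    // HBlockWork and the pad sizes are made-up example values.
    constexpr int HBlockWork = 4;
    constexpr int HPadLow = 1, HPadUp = 1;

    constexpr int pad_low_for(int h_block_work_id) { return h_block_work_id == 0 ? HPadLow : 0; }
    constexpr int pad_up_for(int h_block_work_id) { return h_block_work_id == HBlockWork - 1 ? HPadUp : 0; }

    static_assert(pad_low_for(0) == 1 && pad_up_for(0) == 0, "first block takes only the lower pad");
    static_assert(pad_low_for(3) == 0 && pad_up_for(3) == 1, "last block takes only the upper pad");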
...@@ -8,32 +8,32 @@ ...@@ -8,32 +8,32 @@
#include "blockwise_gemm.hip.hpp" #include "blockwise_gemm.hip.hpp"
// define B = flatten(N, Hi, Wi) // define B = flatten(N, Hi, Wi)
template <unsigned GridSize, template <index_t GridSize,
unsigned BlockSize, index_t BlockSize,
class Float, class Float,
class InGlobalDesc, class InGlobalDesc,
class WeiGlobalDesc, class WeiGlobalDesc,
class OutGlobalDesc, class OutGlobalDesc,
unsigned BPerBlock, index_t BPerBlock,
unsigned KPerBlock, index_t KPerBlock,
unsigned CPerBlock, index_t CPerBlock,
unsigned BPerThread, index_t BPerThread,
unsigned KPerThread, index_t KPerThread,
unsigned GemmThreadPerColumnPerCluster, index_t GemmThreadPerColumnPerCluster,
unsigned GemmThreadPerRowPerCluster, index_t GemmThreadPerRowPerCluster,
unsigned GemmMPerThreadSubC, index_t GemmMPerThreadSubC,
unsigned GemmNPerThreadSubC, index_t GemmNPerThreadSubC,
unsigned GemmMLevel0Cluster, index_t GemmMLevel0Cluster,
unsigned GemmNLevel0Cluster, index_t GemmNLevel0Cluster,
unsigned GemmMLevel1Cluster, index_t GemmMLevel1Cluster,
unsigned GemmNLevel1Cluster, index_t GemmNLevel1Cluster,
unsigned GemmKPerThreadLoop, index_t GemmKPerThreadLoop,
unsigned InBlockCopyThreadPerDim0, index_t InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1, index_t InBlockCopyThreadPerDim1,
unsigned WeiBlockCopyThreadPerDim0, index_t WeiBlockCopyThreadPerDim0,
unsigned WeiBlockCopyThreadPerDim1, index_t WeiBlockCopyThreadPerDim1,
unsigned InBlockCopyDataPerRead, index_t InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead> index_t WeiBlockCopyDataPerRead>
__global__ void __global__ void
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restrict__ p_in_global, gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global, const Float* const __restrict__ p_wei_global,
...@@ -48,30 +48,30 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric ...@@ -48,30 +48,30 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{}; constexpr auto out_khwn_global_desc = OutGlobalDesc{};
constexpr unsigned C = in_chwn_global_desc.GetLength(I0); constexpr index_t C = in_chwn_global_desc.GetLength(I0);
constexpr unsigned Hi = in_chwn_global_desc.GetLength(I1); constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
constexpr unsigned Wi = in_chwn_global_desc.GetLength(I2); constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
constexpr unsigned N = in_chwn_global_desc.GetLength(I3); constexpr index_t N = in_chwn_global_desc.GetLength(I3);
constexpr unsigned K = out_khwn_global_desc.GetLength(I0); constexpr index_t K = out_khwn_global_desc.GetLength(I0);
constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1); constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2); constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1); constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2); constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
constexpr unsigned B = N * Hi * Wi; constexpr index_t B = N * Hi * Wi;
constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1); constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
// divide block work by 2d: [K, B] // divide block work by 2d: [K, B]
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock; constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock;
const unsigned k_block_work_id = get_block_1d_id() / BBlockWork; const index_t k_block_work_id = get_block_1d_id() / BBlockWork;
const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock; const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned b_block_data_begin = b_block_work_id * BPerBlock; const index_t b_block_data_begin = b_block_work_id * BPerBlock;
// flattened (2d) tensor view of gridwise input // flattened (2d) tensor view of gridwise input
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{}); constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
...@@ -192,15 +192,15 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric ...@@ -192,15 +192,15 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
GemmKPerThreadLoop>{}; GemmKPerThreadLoop>{};
// LDS: be careful of alignment // LDS: be careful of alignment
constexpr unsigned in_block_size = constexpr index_t in_block_size =
in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{}); in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
constexpr unsigned wei_block_size = constexpr index_t wei_block_size =
wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{}); wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead ? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead; : WeiBlockCopyDataPerRead;
// LDS // LDS
__shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)]; __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)];
...@@ -218,10 +218,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric ...@@ -218,10 +218,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
// set threadwise output tensor to 0 // set threadwise output tensor to 0
threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread); threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);
for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
__syncthreads()) __syncthreads())
{ {
// load data // load data
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
...@@ -231,18 +231,16 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric ...@@ -231,18 +231,16 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
// compute on current data // compute on current data
// a series of GEMM // a series of GEMM
for(unsigned y = 0; y < Y; ++y) for(index_t y = 0; y < Y; ++y)
{ {
for(unsigned x = 0; x < X; ++x) for(index_t x = 0; x < X; ++x)
{ {
auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 0 #if 0
blockwise_gemm.Run blockwise_gemm.Run
#elif 1
blockwise_gemm.Run_asm
#elif 0
blockwise_gemm.Run_v2
#elif 0 #elif 0
blockwise_gemm.Run_asm
#elif 1
blockwise_gemm.Run_RegisterDoubleBuffer blockwise_gemm.Run_RegisterDoubleBuffer
#endif #endif
(p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), (p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
...@@ -257,23 +255,23 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric ...@@ -257,23 +255,23 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restric
const auto c_thread_mtx_begin = const auto c_thread_mtx_begin =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
{ {
for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
{ {
const auto c_thread_mtx_distance = const auto c_thread_mtx_distance =
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);
unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row; index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row;
unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col; index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col;
unsigned h_data = b_data / (Wi * N); index_t h_data = b_data / (Wi * N);
unsigned itmp = b_data - h_data * (Wi * N); index_t itmp = b_data - h_data * (Wi * N);
unsigned w_data = itmp / N; index_t w_data = itmp / N;
unsigned n_data = itmp - w_data * N; index_t n_data = itmp - w_data * N;
if(n_data < N && h_data < Ho && w_data < Wo) if(n_data < N && h_data < Ho && w_data < Wo)
{ {
......
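The kernel above flattens (N, Hi, Wi) into a single B dimension, splits the [K, B] output plane across blocks with a ceiling division, and finally decodes each flattened b index back into (h, w, n) before writing out. A minimal host-side sketch of that index arithmetic, assuming index_t is a 32-bit unsigned integer and using hypothetical sizes chosen only for illustration:

#include <cassert>
#include <cstdint>

using index_t = std::uint32_t; // assumption: index_t is a 32-bit unsigned index

int main()
{
    // hypothetical sizes, for illustration only
    constexpr index_t N = 2, Hi = 4, Wi = 4, K = 32;
    constexpr index_t KPerBlock = 16, BPerBlock = 64;

    constexpr index_t B = N * Hi * Wi; // flatten (N, Hi, Wi) -> B

    // ceiling division for the 2d [K, B] block-work split
    constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
    constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock;
    static_assert(KBlockWork == 2 && BBlockWork == 1, "grid has KBlockWork * BBlockWork blocks");

    for(index_t b_data = 0; b_data < B; ++b_data)
    {
        // same decomposition as in the kernel epilogue
        const index_t h_data = b_data / (Wi * N);
        const index_t itmp   = b_data - h_data * (Wi * N);
        const index_t w_data = itmp / N;
        const index_t n_data = itmp - w_data * N;

        // round-trip check: b was laid out as ((h * Wi) + w) * N + n
        assert(b_data == (h_data * Wi + w_data) * N + n_data);
    }
    return 0;
}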
...@@ -8,32 +8,32 @@ ...@@ -8,32 +8,32 @@
#include "blockwise_gemm.hip.hpp" #include "blockwise_gemm.hip.hpp"
// define B = flatten(N, Hi, Wi) // define B = flatten(N, Hi, Wi)
template <unsigned GridSize, template <index_t GridSize,
unsigned BlockSize, index_t BlockSize,
class Float, class Float,
class InGlobalDesc, class InGlobalDesc,
class WeiGlobalDesc, class WeiGlobalDesc,
class OutGlobalDesc, class OutGlobalDesc,
unsigned BPerBlock, index_t BPerBlock,
unsigned KPerBlock, index_t KPerBlock,
unsigned CPerBlock, index_t CPerBlock,
unsigned BPerThread, index_t BPerThread,
unsigned KPerThread, index_t KPerThread,
unsigned GemmThreadPerColumnPerCluster, index_t GemmThreadPerColumnPerCluster,
unsigned GemmThreadPerRowPerCluster, index_t GemmThreadPerRowPerCluster,
unsigned GemmMPerThreadSubC, index_t GemmMPerThreadSubC,
unsigned GemmNPerThreadSubC, index_t GemmNPerThreadSubC,
unsigned GemmMLevel0Cluster, index_t GemmMLevel0Cluster,
unsigned GemmNLevel0Cluster, index_t GemmNLevel0Cluster,
unsigned GemmMLevel1Cluster, index_t GemmMLevel1Cluster,
unsigned GemmNLevel1Cluster, index_t GemmNLevel1Cluster,
unsigned GemmKPerThreadLoop, index_t GemmKPerThreadLoop,
unsigned InBlockCopyThreadPerDim0, index_t InBlockCopyThreadPerDim0,
unsigned InBlockCopyThreadPerDim1, index_t InBlockCopyThreadPerDim1,
unsigned WeiBlockCopyThreadPerDim0, index_t WeiBlockCopyThreadPerDim0,
unsigned WeiBlockCopyThreadPerDim1, index_t WeiBlockCopyThreadPerDim1,
unsigned InBlockCopyDataPerRead, index_t InBlockCopyDataPerRead,
unsigned WeiBlockCopyDataPerRead> index_t WeiBlockCopyDataPerRead>
__global__ void __global__ void
#if 0 #if 0
__launch_bounds__(256,2) __launch_bounds__(256,2)
...@@ -52,30 +52,30 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( ...@@ -52,30 +52,30 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{}; constexpr auto out_khwn_global_desc = OutGlobalDesc{};
constexpr unsigned C = in_chwn_global_desc.GetLength(I0); constexpr index_t C = in_chwn_global_desc.GetLength(I0);
constexpr unsigned Hi = in_chwn_global_desc.GetLength(I1); constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
constexpr unsigned Wi = in_chwn_global_desc.GetLength(I2); constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
constexpr unsigned N = in_chwn_global_desc.GetLength(I3); constexpr index_t N = in_chwn_global_desc.GetLength(I3);
constexpr unsigned K = out_khwn_global_desc.GetLength(I0); constexpr index_t K = out_khwn_global_desc.GetLength(I0);
constexpr unsigned Ho = out_khwn_global_desc.GetLength(I1); constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
constexpr unsigned Wo = out_khwn_global_desc.GetLength(I2); constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
constexpr unsigned Y = wei_cyxk_global_desc.GetLength(I1); constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
constexpr unsigned X = wei_cyxk_global_desc.GetLength(I2); constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
constexpr unsigned B = N * Hi * Wi; constexpr index_t B = N * Hi * Wi;
constexpr unsigned BGhostRead = (Y - 1) * Wi + (X - 1); constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
// divide block work by 2d: [K, B] // divide block work by 2d: [K, B]
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock; constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock; constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock;
const unsigned k_block_work_id = get_block_1d_id() / BBlockWork; const index_t k_block_work_id = get_block_1d_id() / BBlockWork;
const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
const unsigned k_block_data_begin = k_block_work_id * KPerBlock; const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const unsigned b_block_data_begin = b_block_work_id * BPerBlock; const index_t b_block_data_begin = b_block_work_id * BPerBlock;
// flattened (2d) tensor view of gridwise input // flattened (2d) tensor view of gridwise input
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{}); constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
...@@ -210,15 +210,15 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( ...@@ -210,15 +210,15 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
#endif #endif
// LDS: be careful of alignment // LDS: be careful of alignment
constexpr unsigned in_block_size = constexpr index_t in_block_size =
in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{}); in_cb_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
constexpr unsigned wei_block_size = constexpr index_t wei_block_size =
wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{}); wei_cyxk_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
? InBlockCopyDataPerRead ? InBlockCopyDataPerRead
: WeiBlockCopyDataPerRead; : WeiBlockCopyDataPerRead;
// LDS double buffer // LDS double buffer
__shared__ Float p_in_block_0[max_align * ((in_block_size + max_align - 1) / max_align)]; __shared__ Float p_in_block_0[max_align * ((in_block_size + max_align - 1) / max_align)];
...@@ -248,11 +248,11 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( ...@@ -248,11 +248,11 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
bool even_loop = true; bool even_loop = true;
for(unsigned c_block_data_begin = 0; c_block_data_begin + CPerBlock < C; for(index_t c_block_data_begin = 0; c_block_data_begin + CPerBlock < C;
c_block_data_begin += CPerBlock, c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
even_loop = !even_loop) even_loop = !even_loop)
{ {
Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1; Float* p_in_block_now = even_loop ? p_in_block_0 : p_in_block_1;
Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1; Float* p_wei_block_now = even_loop ? p_wei_block_0 : p_wei_block_1;
...@@ -279,12 +279,12 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( ...@@ -279,12 +279,12 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
// compute on current data // compute on current data
// a series of GEMM // a series of GEMM
for(unsigned y = 0; y < Y; ++y) for(index_t y = 0; y < Y; ++y)
{ {
for(unsigned x = 0; x < X; ++x) for(index_t x = 0; x < X; ++x)
{ {
auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 0 #if 1
blockwise_gemm.Run blockwise_gemm.Run
#else #else
blockwise_gemm.Run_RegisterDoubleBuffer blockwise_gemm.Run_RegisterDoubleBuffer
...@@ -309,12 +309,12 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( ...@@ -309,12 +309,12 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
__syncthreads(); __syncthreads();
for(unsigned y = 0; y < Y; ++y) for(index_t y = 0; y < Y; ++y)
{ {
for(unsigned x = 0; x < X; ++x) for(index_t x = 0; x < X; ++x)
{ {
auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
#if 0 #if 1
blockwise_gemm.Run blockwise_gemm.Run
#else #else
blockwise_gemm.Run_RegisterDoubleBuffer blockwise_gemm.Run_RegisterDoubleBuffer
...@@ -331,8 +331,8 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( ...@@ -331,8 +331,8 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
const auto c_thread_mtx_begin = const auto c_thread_mtx_begin =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const unsigned k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
const unsigned b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
#if 0 #if 0
if(get_block_1d_id() == 0) if(get_block_1d_id() == 0)
...@@ -348,20 +348,20 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer( ...@@ -348,20 +348,20 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
} }
#endif #endif
for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
{ {
for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
{ {
const auto c_thread_mtx_distance = const auto c_thread_mtx_distance =
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);
unsigned k_data = k_thread_data_begin + c_thread_mtx_distance.row; index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row;
unsigned b_data = b_thread_data_begin + c_thread_mtx_distance.col; index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col;
unsigned h_data = b_data / (Wi * N); index_t h_data = b_data / (Wi * N);
unsigned itmp = b_data - h_data * (Wi * N); index_t itmp = b_data - h_data * (Wi * N);
unsigned w_data = itmp / N; index_t w_data = itmp / N;
unsigned n_data = itmp - w_data * N; index_t n_data = itmp - w_data * N;
if(n_data < N && h_data < Ho && w_data < Wo) if(n_data < N && h_data < Ho && w_data < Wo)
{ {
......
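Two details of the double-buffer variant are easy to miss: the LDS arrays are padded up to a multiple of max_align so that reads of InBlockCopyDataPerRead / WeiBlockCopyDataPerRead elements stay aligned, and the main loop flips an even_loop flag to alternate between the two buffers. A small host-side sketch of both, under the same assumptions as before (hypothetical sizes, index_t as a plain unsigned integer):

#include <cstdio>

using index_t = unsigned int; // assumption: plain unsigned index type

// round size up to the next multiple of align, as done for the LDS arrays
constexpr index_t round_up(index_t size, index_t align)
{
    return align * ((size + align - 1) / align);
}

int main()
{
    // hypothetical values, for illustration only
    constexpr index_t in_block_size     = 1001;
    constexpr index_t wei_block_size    = 518;
    constexpr index_t in_data_per_read  = 4;
    constexpr index_t wei_data_per_read = 2;

    constexpr index_t max_align =
        in_data_per_read > wei_data_per_read ? in_data_per_read : wei_data_per_read;

    static_assert(round_up(in_block_size, max_align) == 1004, "padded to a multiple of 4");
    static_assert(round_up(wei_block_size, max_align) == 520, "padded to a multiple of 4");

    // ping-pong buffer selection, as in the even_loop toggle of the main loop
    float buf0[8] = {}, buf1[8] = {};
    bool  even_loop = true;

    constexpr index_t C = 8, CPerBlock = 2;
    // mirrors the kernel's loop bound (c_block_data_begin + CPerBlock < C)
    for(index_t c = 0; c + CPerBlock < C; c += CPerBlock, even_loop = !even_loop)
    {
        float* p_block_now = even_loop ? buf0 : buf1; // buffer consumed this iteration
        std::printf("c = %u uses %s\n", c, p_block_now == buf0 ? "buffer 0" : "buffer 1");
    }
    return 0;
}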
...@@ -16,11 +16,11 @@ __device__ void threadwise_2d_tensor_pointwise_operation_unary(Desc, Float* __re ...@@ -16,11 +16,11 @@ __device__ void threadwise_2d_tensor_pointwise_operation_unary(Desc, Float* __re
} }
#endif #endif
for(unsigned did0 = 0; did0 < desc.GetLength(I0); ++did0) for(index_t did0 = 0; did0 < desc.GetLength(I0); ++did0)
{ {
for(unsigned did1 = 0; did1 < desc.GetLength(I1); ++did1) for(index_t did1 = 0; did1 < desc.GetLength(I1); ++did1)
{ {
const unsigned dindex = desc.Get1dIndex(did0, did1); const index_t dindex = desc.Get1dIndex(did0, did1);
f(p[dindex]); f(p[dindex]);
} }
...@@ -47,22 +47,22 @@ __device__ void threadwise_2d_tensor_pointwise_operation_binary_reorder_by_get_d ...@@ -47,22 +47,22 @@ __device__ void threadwise_2d_tensor_pointwise_operation_binary_reorder_by_get_d
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0); constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1); constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
constexpr auto src_desc = SrcDesc{}; constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{}; constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{}); constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0) for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
{ {
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1) for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
{ {
const unsigned aindex = src_desc.Get1dIndex(did0, did1); const index_t aindex = src_desc.Get1dIndex(did0, did1);
const unsigned did[2] = {did0, did1}; const index_t did[2] = {did0, did1};
const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]); const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
f(p_src[aindex], p_dst[bindex]); f(p_src[aindex], p_dst[bindex]);
} }
...@@ -118,21 +118,21 @@ __device__ void threadwise_2d_tensor_shift_down(Desc, Float* __restrict__ p, IDi ...@@ -118,21 +118,21 @@ __device__ void threadwise_2d_tensor_shift_down(Desc, Float* __restrict__ p, IDi
} }
#endif #endif
constexpr unsigned nshift = NShift::mValue; constexpr index_t nshift = NShift::mValue;
constexpr unsigned did0_end = constexpr index_t did0_end =
is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0); is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0);
constexpr unsigned did1_end = constexpr index_t did1_end =
is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - nshift : desc.GetLength(I1); is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - nshift : desc.GetLength(I1);
for(unsigned did0 = 0; did0 < did0_end; ++did0) for(index_t did0 = 0; did0 < did0_end; ++did0)
{ {
for(unsigned did1 = 0; did1 < did1_end; ++did1) for(index_t did1 = 0; did1 < did1_end; ++did1)
{ {
const unsigned dindex = desc.Get1dIndex(did0, did1); const index_t dindex = desc.Get1dIndex(did0, did1);
const unsigned sindex = dindex + nshift * desc.GetStride(IDim{}); const index_t sindex = dindex + nshift * desc.GetStride(IDim{});
p[dindex] = p[sindex]; p[dindex] = p[sindex];
} }
......
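threadwise_2d_tensor_shift_down copies each element from nshift positions further along the chosen dimension, with the loop bound on that dimension shrunk by nshift. A plain row-major 2d sketch of the same move, shifting along dimension 0 with hypothetical lengths:

#include <cassert>

using index_t = unsigned int; // assumption: plain unsigned index type

int main()
{
    constexpr index_t L0 = 4, L1 = 3, nshift = 1;
    constexpr index_t stride0 = L1, stride1 = 1; // row-major strides

    float p[L0 * L1];
    for(index_t i = 0; i < L0 * L1; ++i)
        p[i] = static_cast<float>(i);

    // shift down along dimension 0: rows 1..L0-1 move into rows 0..L0-2
    for(index_t did0 = 0; did0 < L0 - nshift; ++did0)
    {
        for(index_t did1 = 0; did1 < L1; ++did1)
        {
            const index_t dindex = did0 * stride0 + did1 * stride1;
            const index_t sindex = dindex + nshift * stride0;
            p[dindex] = p[sindex];
        }
    }

    assert(p[0] == 3.0f); // old row 1 now sits at row 0
    return 0;
}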
...@@ -18,15 +18,15 @@ __device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __re ...@@ -18,15 +18,15 @@ __device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __re
} }
#endif #endif
for(unsigned did0 = 0; did0 < desc.GetLength(I0); ++did0) for(index_t did0 = 0; did0 < desc.GetLength(I0); ++did0)
{ {
for(unsigned did1 = 0; did1 < desc.GetLength(I1); ++did1) for(index_t did1 = 0; did1 < desc.GetLength(I1); ++did1)
{ {
for(unsigned did2 = 0; did2 < desc.GetLength(I2); ++did2) for(index_t did2 = 0; did2 < desc.GetLength(I2); ++did2)
{ {
for(unsigned did3 = 0; did3 < desc.GetLength(I3); ++did3) for(index_t did3 = 0; did3 < desc.GetLength(I3); ++did3)
{ {
const unsigned dindex = desc.Get1dIndex(did0, did1, did2, did3); const index_t dindex = desc.Get1dIndex(did0, did1, did2, did3);
f(p[dindex]); f(p[dindex]);
} }
...@@ -58,28 +58,28 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d ...@@ -58,28 +58,28 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0); constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1); constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
constexpr unsigned IR2 = DstFromSrcReorder{}.Get(I2); constexpr index_t IR2 = DstFromSrcReorder{}.Get(I2);
constexpr unsigned IR3 = DstFromSrcReorder{}.Get(I3); constexpr index_t IR3 = DstFromSrcReorder{}.Get(I3);
constexpr auto src_desc = SrcDesc{}; constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{}; constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{}); constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0) for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
{ {
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1) for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
{ {
for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2) for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
{ {
for(unsigned did3 = 0; did3 < ref_desc.GetLength(I3); ++did3) for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
{ {
const unsigned aindex = src_desc.Get1dIndex(did0, did1, did2, did3); const index_t aindex = src_desc.Get1dIndex(did0, did1, did2, did3);
const unsigned did[4] = {did0, did1, did2, did3}; const index_t did[4] = {did0, did1, did2, did3};
const unsigned bindex = const index_t bindex =
dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]); dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
f(p_src[aindex], p_dst[bindex]); f(p_src[aindex], p_dst[bindex]);
...@@ -129,7 +129,7 @@ __device__ void threadwise_4d_tensor_copy( ...@@ -129,7 +129,7 @@ __device__ void threadwise_4d_tensor_copy(
} }
// need to assume src and dst are aligned // need to assume src and dst are aligned
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, unsigned DataPerRead> template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
__device__ void threadwise_4d_tensor_copy_v2(SrcDesc, __device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
const Float* __restrict__ p_src, const Float* __restrict__ p_src,
DstDesc, DstDesc,
...@@ -163,24 +163,24 @@ __device__ void threadwise_4d_tensor_copy_v2(SrcDesc, ...@@ -163,24 +163,24 @@ __device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
DstDesc{}.GetStride(I2) % DataPerRead == 0, DstDesc{}.GetStride(I2) % DataPerRead == 0,
"wrong! src and dst stride should be multiple of DataPerRead to keep alignment"); "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
constexpr unsigned L3 = SrcOpLengths{}.Get(I3); constexpr index_t L3 = SrcOpLengths{}.Get(I3);
static_assert(L3 % DataPerRead == 0, "wrong! L3 should be evenly divided by DataPerRead"); static_assert(L3 % DataPerRead == 0, "wrong! L3 should be evenly divided by DataPerRead");
constexpr unsigned nloop_d3 = L3 / DataPerRead; constexpr index_t nloop_d3 = L3 / DataPerRead;
for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0) for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
{ {
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1) for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
{ {
for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2) for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
{ {
for(unsigned iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3) for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
{ {
const unsigned src_index = const index_t src_index =
src_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead); src_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);
const unsigned dst_index = const index_t dst_index =
dst_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead); dst_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);
if(DataPerRead == 1) if(DataPerRead == 1)
...@@ -224,31 +224,31 @@ __device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDi ...@@ -224,31 +224,31 @@ __device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDi
} }
#endif #endif
constexpr unsigned nshift = NShift::mValue; constexpr index_t nshift = NShift::mValue;
constexpr unsigned did0_end = constexpr index_t did0_end =
is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0); is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0);
constexpr unsigned did1_end = constexpr index_t did1_end =
is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - nshift : desc.GetLength(I1); is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - nshift : desc.GetLength(I1);
constexpr unsigned did2_end = constexpr index_t did2_end =
is_same<decltype(I2), IDim>::value ? desc.GetLength(I2) - nshift : desc.GetLength(I2); is_same<decltype(I2), IDim>::value ? desc.GetLength(I2) - nshift : desc.GetLength(I2);
constexpr unsigned did3_end = constexpr index_t did3_end =
is_same<decltype(I3), IDim>::value ? desc.GetLength(I3) - nshift : desc.GetLength(I3); is_same<decltype(I3), IDim>::value ? desc.GetLength(I3) - nshift : desc.GetLength(I3);
for(unsigned did0 = 0; did0 < did0_end; ++did0) for(index_t did0 = 0; did0 < did0_end; ++did0)
{ {
for(unsigned did1 = 0; did1 < did1_end; ++did1) for(index_t did1 = 0; did1 < did1_end; ++did1)
{ {
for(unsigned did2 = 0; did2 < did2_end; ++did2) for(index_t did2 = 0; did2 < did2_end; ++did2)
{ {
for(unsigned did3 = 0; did3 < did3_end; ++did3) for(index_t did3 = 0; did3 < did3_end; ++did3)
{ {
const unsigned dindex = desc.Get1dIndex(did0, did1, did2, did3); const index_t dindex = desc.Get1dIndex(did0, did1, did2, did3);
const unsigned sindex = dindex + nshift * desc.GetStride(IDim{}); const index_t sindex = dindex + nshift * desc.GetStride(IDim{});
p[dindex] = p[sindex]; p[dindex] = p[sindex];
} }
......
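threadwise_4d_tensor_copy_v2 vectorizes only the innermost dimension: L3 must be a multiple of DataPerRead, the relevant strides must preserve alignment, and the inner loop then moves DataPerRead elements per iteration through a vector type. A 1d sketch of that inner loop, with a simple aggregate standing in for vector_type<float, 4>::MemoryType:

#include <cassert>

using index_t = unsigned int; // assumption: plain unsigned index type

struct float4_t { float x, y, z, w; }; // stand-in for vector_type<float, 4>::MemoryType

int main()
{
    constexpr index_t DataPerRead = 4;
    constexpr index_t L3 = 8; // innermost length, must be a multiple of DataPerRead
    static_assert(L3 % DataPerRead == 0, "L3 should be evenly divided by DataPerRead");

    constexpr index_t nloop_d3 = L3 / DataPerRead;

    alignas(16) float src[L3] = {0, 1, 2, 3, 4, 5, 6, 7};
    alignas(16) float dst[L3] = {};

    for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
    {
        const index_t idx = iloop_d3 * DataPerRead;
        // one 4-wide move instead of four scalar moves (same cast pattern as the kernel)
        *reinterpret_cast<float4_t*>(dst + idx) =
            *reinterpret_cast<const float4_t*>(src + idx);
    }

    assert(dst[5] == 5.0f);
    return 0;
}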
...@@ -28,28 +28,28 @@ __device__ void threadwise_direct_convolution_1(InDesc, ...@@ -28,28 +28,28 @@ __device__ void threadwise_direct_convolution_1(InDesc,
} }
#endif #endif
for(unsigned n = 0; n < out_desc.GetLength(I0); ++n) for(index_t n = 0; n < out_desc.GetLength(I0); ++n)
{ {
for(unsigned k = 0; k < out_desc.GetLength(I1); ++k) for(index_t k = 0; k < out_desc.GetLength(I1); ++k)
{ {
for(unsigned ho = 0; ho < out_desc.GetLength(I2); ++ho) for(index_t ho = 0; ho < out_desc.GetLength(I2); ++ho)
{ {
for(unsigned wo = 0; wo < out_desc.GetLength(I3); ++wo) for(index_t wo = 0; wo < out_desc.GetLength(I3); ++wo)
{ {
for(unsigned c = 0; c < wei_desc.GetLength(I1); ++c) for(index_t c = 0; c < wei_desc.GetLength(I1); ++c)
{ {
for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y) for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
{ {
for(unsigned x = 0; x < wei_desc.GetLength(I3); ++x) for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
{ {
const unsigned hi = ho + y; const index_t hi = ho + y;
const unsigned wi = wo + x; const index_t wi = wo + x;
const unsigned in_index = in_desc.Get1dIndex(n, c, hi, wi); const index_t in_index = in_desc.Get1dIndex(n, c, hi, wi);
const unsigned wei_index = wei_desc.Get1dIndex(k, c, y, x); const index_t wei_index = wei_desc.Get1dIndex(k, c, y, x);
const unsigned out_index = out_desc.Get1dIndex(n, k, ho, wo); const index_t out_index = out_desc.Get1dIndex(n, k, ho, wo);
fused_multiply_accumulate( fused_multiply_accumulate(
p_out[out_index], p_wei[wei_index], p_in[in_index]); p_out[out_index], p_wei[wei_index], p_in[in_index]);
...@@ -125,7 +125,7 @@ __device__ void threadwise_direct_convolution_3(InDesc, ...@@ -125,7 +125,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
Data p_in_reg[in_reg_desc.GetElementSpace()]; Data p_in_reg[in_reg_desc.GetElementSpace()];
Data p_wei_reg[wei_reg_desc.GetElementSpace()]; Data p_wei_reg[wei_reg_desc.GetElementSpace()];
constexpr unsigned in_w_new_read = 1; constexpr index_t in_w_new_read = 1;
constexpr auto in_desc_reg_new_read = constexpr auto in_desc_reg_new_read =
make_ConstantTensorDescriptor(Sequence<in_reg_desc.GetLength(I0), make_ConstantTensorDescriptor(Sequence<in_reg_desc.GetLength(I0),
...@@ -136,7 +136,7 @@ __device__ void threadwise_direct_convolution_3(InDesc, ...@@ -136,7 +136,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
#if 0 #if 0
// this version reuses old input data in registers, and reads new data from LDS // this version reuses old input data in registers, and reads new data from LDS
// loop over vertical direction // loop over vertical direction
for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y) for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
{ {
// read first input // read first input
threadwise_4d_tensor_copy(in_desc, threadwise_4d_tensor_copy(in_desc,
...@@ -157,7 +157,7 @@ __device__ void threadwise_direct_convolution_3(InDesc, ...@@ -157,7 +157,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out); in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
// loop over horizontal direction // loop over horizontal direction
for(unsigned x = 1; x < wei_desc.GetLength(I3); ++x) for(index_t x = 1; x < wei_desc.GetLength(I3); ++x)
{ {
// read new weight // read new weight
threadwise_4d_tensor_copy(wei_desc, threadwise_4d_tensor_copy(wei_desc,
...@@ -186,10 +186,10 @@ __device__ void threadwise_direct_convolution_3(InDesc, ...@@ -186,10 +186,10 @@ __device__ void threadwise_direct_convolution_3(InDesc,
#elif 1 #elif 1
// this version reads all input from LDS when the filter moves // this version reads all input from LDS when the filter moves
// loop over vertical direction // loop over vertical direction
for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y) for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
{ {
// loop over horizontal direction // loop over horizontal direction
for(unsigned x = 0; x < wei_desc.GetLength(I3); ++x) for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
{ {
// read new weight // read new weight
threadwise_4d_tensor_copy(wei_desc, threadwise_4d_tensor_copy(wei_desc,
......
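The seven-deep loop nest in threadwise_direct_convolution_1 is the plain definition of a stride-1, no-padding convolution: out[n, k, ho, wo] += wei[k, c, y, x] * in[n, c, ho + y, wo + x]. A small CPU reference of the same nest, with hypothetical NCHW / KCYX sizes, can serve as a check for the threadwise variants:

#include <vector>

using index_t = unsigned int; // assumption: plain unsigned index type

int main()
{
    // hypothetical sizes, for illustration only
    constexpr index_t N = 1, K = 2, C = 3, Y = 3, X = 3, Ho = 4, Wo = 4;
    constexpr index_t Hi = Ho + Y - 1, Wi = Wo + X - 1;

    std::vector<float> in(N * C * Hi * Wi, 1.0f);
    std::vector<float> wei(K * C * Y * X, 1.0f);
    std::vector<float> out(N * K * Ho * Wo, 0.0f);

    for(index_t n = 0; n < N; ++n)
        for(index_t k = 0; k < K; ++k)
            for(index_t ho = 0; ho < Ho; ++ho)
                for(index_t wo = 0; wo < Wo; ++wo)
                    for(index_t c = 0; c < C; ++c)
                        for(index_t y = 0; y < Y; ++y)
                            for(index_t x = 0; x < X; ++x)
                            {
                                const index_t hi = ho + y; // stride 1, no padding
                                const index_t wi = wo + x;
                                out[((n * K + k) * Ho + ho) * Wo + wo] +=
                                    wei[((k * C + c) * Y + y) * X + x] *
                                    in[((n * C + c) * Hi + hi) * Wi + wi];
                            }

    // every output element accumulates C * Y * X products of 1 * 1
    return out[0] == static_cast<float>(C * Y * X) ? 0 : 1;
}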
#pragma once #pragma once
template <class Float, class SrcMatrix, class DstMatrix, unsigned NRow, unsigned NCol> template <class Float, class SrcMatrix, class DstMatrix, index_t NRow, index_t NCol>
__device__ void threadwise_matrix_copy(SrcMatrix, __device__ void threadwise_matrix_copy(SrcMatrix,
const Float* __restrict__ p_src, const Float* __restrict__ p_src,
DstMatrix, DstMatrix,
...@@ -10,16 +10,39 @@ __device__ void threadwise_matrix_copy(SrcMatrix, ...@@ -10,16 +10,39 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
constexpr auto src_mtx = SrcMatrix{}; constexpr auto src_mtx = SrcMatrix{};
constexpr auto dst_mtx = DstMatrix{}; constexpr auto dst_mtx = DstMatrix{};
for(unsigned i = 0; i < NRow; ++i) #if 0
for(index_t i = 0; i < NRow; ++i)
{ {
for(unsigned j = 0; j < NCol; ++j) for(index_t j = 0; j < NCol; ++j)
{ {
const unsigned src_index = src_mtx.Get1dIndex(i, j); const index_t src_index = src_mtx.Get1dIndex(i, j);
const unsigned dst_index = dst_mtx.Get1dIndex(i, j); const index_t dst_index = dst_mtx.Get1dIndex(i, j);
p_dst[dst_index] = p_src[src_index]; p_dst[dst_index] = p_src[src_index];
} }
} }
#elif 1
static_assert(NCol == 4, "only for NCol == 4");
using vector_t = typename vector_type<Float, 4>::MemoryType;
for(index_t i = 0; i < NRow; ++i)
{
const index_t src_index = src_mtx.Get1dIndex(i, 0);
const index_t dst_index = dst_mtx.Get1dIndex(i, 0);
#if 1
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
*(reinterpret_cast<const vector_t*>(p_src + src_index));
#elif 1
asm volatile("\n \
ds_read_b128 %0, %1, offset:0 \n \
"
: "=v"(*(reinterpret_cast<vector_t*>(p_dst+dst_index)))
: "v"((uint32_t)(p_src + src_index)));
#endif
}
#endif
} }
template <class MatrixA, template <class MatrixA,
...@@ -49,21 +72,31 @@ __device__ void threadwise_gemm(MatrixA, ...@@ -49,21 +72,31 @@ __device__ void threadwise_gemm(MatrixA,
constexpr auto b_mtx = MatrixB{}; constexpr auto b_mtx = MatrixB{};
constexpr auto c_mtx = MatrixC{}; constexpr auto c_mtx = MatrixC{};
constexpr unsigned M = c_mtx.NRow(); constexpr index_t M = c_mtx.NRow();
constexpr unsigned N = c_mtx.NCol(); constexpr index_t N = c_mtx.NCol();
constexpr unsigned K = a_mtx.NRow(); // A is transposed constexpr index_t K = a_mtx.NRow(); // A is transposed
for(unsigned k = 0; k < K; ++k) for(index_t k = 0; k < K; ++k)
{ {
for(unsigned i = 0; i < M; ++i) for(index_t i = 0; i < M; ++i)
{ {
for(unsigned j = 0; j < N; ++j) for(index_t j = 0; j < N; ++j)
{ {
const unsigned aindex = a_mtx.Get1dIndex(k, i); // A is transposed const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed
const unsigned bindex = b_mtx.Get1dIndex(k, j); const index_t bindex = b_mtx.Get1dIndex(k, j);
const unsigned cindex = c_mtx.Get1dIndex(i, j); const index_t cindex = c_mtx.Get1dIndex(i, j);
#if 0
f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]); f_accum(p_c_thread[cindex], p_a_thread[aindex] * p_b_thread[bindex]);
#elif 1
asm volatile("\n \
v_mac_f32 %0, %1, %2 \n \
"
: "=v"(p_c_thread[cindex])
: "v"(p_a_thread[aindex]),
"v"(p_b_thread[bindex]),
"0"(p_c_thread[cindex]));
#endif
} }
} }
} }
......
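threadwise_gemm keeps A transposed (K rows, M columns) and loops over k outermost, so every k iteration adds a rank-1 update to the C sub-tile; the v_mac_f32 branch is simply a fused multiply-accumulate on the same operands. A small row-major sketch of that loop order with hypothetical 2x2x2 matrices:

#include <cassert>

using index_t = unsigned int; // assumption: plain unsigned index type

int main()
{
    constexpr index_t M = 2, N = 2, K = 2;

    // A is stored transposed: a[k][m]
    const float a[K][M] = {{1.0f, 2.0f}, {3.0f, 4.0f}};
    const float b[K][N] = {{5.0f, 6.0f}, {7.0f, 8.0f}};
    float c[M][N]       = {};

    // k outermost: each k adds the outer product of a[k][:] and b[k][:]
    for(index_t k = 0; k < K; ++k)
        for(index_t i = 0; i < M; ++i)
            for(index_t j = 0; j < N; ++j)
                c[i][j] += a[k][i] * b[k][j]; // f_accum(c, a * b)

    // reference: C = A^T * B computed element-wise
    assert(c[0][0] == 1.0f * 5.0f + 3.0f * 7.0f); // 26
    assert(c[1][1] == 2.0f * 6.0f + 4.0f * 8.0f); // 44
    return 0;
}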
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "ConstantTensorDescriptor.hip.hpp" #include "ConstantTensorDescriptor.hip.hpp"
// need to assume src and dst are aligned // need to assume src and dst are aligned
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, unsigned DataPerRead> template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
__device__ void threadwise_6d_tensor_copy(SrcDesc, __device__ void threadwise_6d_tensor_copy(SrcDesc,
const Float* __restrict__ p_src, const Float* __restrict__ p_src,
DstDesc, DstDesc,
...@@ -37,28 +37,28 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc, ...@@ -37,28 +37,28 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc,
DstDesc{}.GetStride(I4) % DataPerRead == 0, DstDesc{}.GetStride(I4) % DataPerRead == 0,
"wrong! src and dst stride should be multiple of DataPerRead to keep alignment"); "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
constexpr unsigned L5 = SrcOpLengths{}.Get(I5); constexpr index_t L5 = SrcOpLengths{}.Get(I5);
static_assert(L5 % DataPerRead == 0, "wrong! L5 should be evenly divided by DataPerRead"); static_assert(L5 % DataPerRead == 0, "wrong! L5 should be evenly divided by DataPerRead");
constexpr unsigned nloop_d5 = L5 / DataPerRead; constexpr index_t nloop_d5 = L5 / DataPerRead;
for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0) for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
{ {
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1) for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
{ {
for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2) for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
{ {
for(unsigned did3 = 0; did3 < ref_desc.GetLength(I3); ++did3) for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
{ {
for(unsigned did4 = 0; did4 < ref_desc.GetLength(I4); ++did4) for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
{ {
for(unsigned iloop_d5 = 0; iloop_d5 < nloop_d5; ++iloop_d5) for(index_t iloop_d5 = 0; iloop_d5 < nloop_d5; ++iloop_d5)
{ {
const unsigned src_index = src_desc.Get1dIndex( const index_t src_index = src_desc.Get1dIndex(
did0, did1, did2, did3, did4, iloop_d5 * DataPerRead); did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);
const unsigned dst_index = dst_desc.Get1dIndex( const index_t dst_index = dst_desc.Get1dIndex(
did0, did1, did2, did3, did4, iloop_d5 * DataPerRead); did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) = *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
...@@ -72,7 +72,7 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc, ...@@ -72,7 +72,7 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc,
} }
// need to assume src and dst is aligned // need to assume src and dst is aligned
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, unsigned DataPerRead> template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
__device__ void threadwise_8d_tensor_copy(SrcDesc, __device__ void threadwise_8d_tensor_copy(SrcDesc,
const Float* __restrict__ p_src, const Float* __restrict__ p_src,
DstDesc, DstDesc,
...@@ -109,29 +109,29 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc, ...@@ -109,29 +109,29 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc,
DstDesc{}.GetStride(I6) % DataPerRead == 0, DstDesc{}.GetStride(I6) % DataPerRead == 0,
"wrong! src and dst stride should be multiple of DataPerRead to keep alignment"); "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
constexpr unsigned L7 = SrcOpLengths{}.Get(I7); constexpr index_t L7 = SrcOpLengths{}.Get(I7);
static_assert(L7 % DataPerRead == 0, "wrong! L7 should be evenly divided by DataPerRead"); static_assert(L7 % DataPerRead == 0, "wrong! L7 should be evenly divided by DataPerRead");
constexpr unsigned nloop_d7 = L7 / DataPerRead; constexpr index_t nloop_d7 = L7 / DataPerRead;
for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0) for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
{ {
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1) for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
{ {
for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2) for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
{ {
for(unsigned did3 = 0; did3 < ref_desc.GetLength(I3); ++did3) for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
{ {
for(unsigned did4 = 0; did4 < ref_desc.GetLength(I4); ++did4) for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
{ {
for(unsigned did5 = 0; did5 < ref_desc.GetLength(I5); ++did5) for(index_t did5 = 0; did5 < ref_desc.GetLength(I5); ++did5)
{ {
for(unsigned did6 = 0; did6 < ref_desc.GetLength(I6); ++did6) for(index_t did6 = 0; did6 < ref_desc.GetLength(I6); ++did6)
{ {
for(unsigned iloop_d7 = 0; iloop_d7 < nloop_d7; ++iloop_d7) for(index_t iloop_d7 = 0; iloop_d7 < nloop_d7; ++iloop_d7)
{ {
const unsigned src_index = const index_t src_index =
src_desc.Get1dIndex(did0, src_desc.Get1dIndex(did0,
did1, did1,
did2, did2,
...@@ -141,7 +141,7 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc, ...@@ -141,7 +141,7 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc,
did6, did6,
iloop_d7 * DataPerRead); iloop_d7 * DataPerRead);
const unsigned dst_index = const index_t dst_index =
dst_desc.Get1dIndex(did0, dst_desc.Get1dIndex(did0,
did1, did1,
did2, did2,
......