Commit 2603bb0f authored by Chao Liu

tuning on vega 20

parent a9031464
@@ -140,7 +140,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
-#elif 1
+#elif 0
// for 3x3, 34x34, v1r3, Pascal
// for 3x3, 28x28, v1r3, Pascal
// for 3x3, 14x14, v1r3, Pascal
@@ -206,6 +206,8 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t OutThreadCopyDataPerWrite_N = 1;
#elif 0
// for 3x3, 34x34, v1r1, Vega 20
+constexpr index_t BlockSize = 256;
constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 4;
@@ -227,16 +229,43 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
-constexpr index_t InBlockCopy_ThreadPerDimC = 4;
-constexpr index_t InBlockCopy_ThreadPerDimH = 4;
-constexpr index_t InBlockCopy_ThreadPerDimW = 2;
-constexpr index_t InBlockCopy_ThreadPerDimN = 8;
+using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 8>;
constexpr index_t InBlockCopyDataPerRead_N = 2;
constexpr index_t WeiBlockCopyDataPerRead_K = 2;
constexpr index_t OutThreadCopyDataPerWrite_N = 4;
+#elif 1
+// for 3x3, 34x34, v1r3, Vega 20
+constexpr index_t BlockSize = 256;
+constexpr index_t NPerBlock = 16;
+constexpr index_t KPerBlock = 128;
+constexpr index_t CPerBlock = 8;
+constexpr index_t HoPerBlock = 2;
+constexpr index_t WoPerBlock = 4;
+constexpr index_t NPerThread = 4;
+constexpr index_t KPerThread = 8;
+constexpr index_t HoPerThread = 1;
+constexpr index_t WoPerThread = 2;
+constexpr index_t GemmMPerThreadSubC = 4;
+constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmMLevel0Cluster = 4;
+constexpr index_t GemmNLevel0Cluster = 4;
+constexpr index_t GemmMLevel1Cluster = 4;
+constexpr index_t GemmNLevel1Cluster = 2;
+constexpr index_t GemmKPerThreadLoop = 1;
+constexpr index_t GemmDataPerReadA = 4;
+constexpr index_t GemmDataPerReadB = 4;
+using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 4, 4>;
+constexpr index_t InBlockCopyDataPerRead_N = 4;
+constexpr index_t WeiBlockCopyDataPerRead_K = 4;
+constexpr index_t OutThreadCopyDataPerWrite_N = 4;
#elif 0
// for 3x3, 56x56, v1r1, Pascal
constexpr index_t NPerBlock = 32;
@@ -448,7 +477,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
-#elif 0
+#elif 1
GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
#endif
<GridSize,
...
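The `#elif 1` block added above is the substance of this commit: a v1r3 tuning configuration for the 3x3, 34x34 case on Vega 20. As a quick plausibility check of those numbers (my own arithmetic and assumed constraints, not part of the commit), both the per-thread output tiling and the `Sequence<8, 2, 4, 4>` copy cluster account for exactly the 256 threads declared in `BlockSize`:

```cpp
// Sketch only: assumed consistency checks for the new Vega 20 v1r3 parameters.
// The exact constraints the kernel enforces may differ; the values are copied
// from the diff above.
int main()
{
    constexpr int BlockSize = 256;
    constexpr int NPerBlock = 16, KPerBlock = 128, HoPerBlock = 2, WoPerBlock = 4;
    constexpr int NPerThread = 4, KPerThread = 8, HoPerThread = 1, WoPerThread = 2;

    // InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 4, 4>
    constexpr int ClusterC = 8, ClusterH = 2, ClusterW = 4, ClusterN = 4;

    // assumption: every thread owns one K x Ho x Wo x N output sub-tile
    constexpr int threads_for_output =
        (KPerBlock / KPerThread) *
        (NPerBlock * HoPerBlock * WoPerBlock) / (NPerThread * HoPerThread * WoPerThread);
    static_assert(threads_for_output == BlockSize, "output tiles do not cover the block");

    // assumption: the blockwise input copy uses each thread exactly once
    static_assert(ClusterC * ClusterH * ClusterW * ClusterN == BlockSize,
                  "copy cluster size != BlockSize");

    return 0;
}
```

If either check failed for a new configuration, the kernel presumably rejects it at compile time anyway; the sketch just makes the arithmetic behind this particular tuning visible.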
@@ -182,7 +182,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
constexpr auto gridwise_conv =
#if 0
GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
-#elif 1
+#elif 0
GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
...
@@ -603,9 +603,9 @@ int main(int argc, char* argv[])
device_direct_convolution_2_nchw_kcyx_nkhw
#elif 0
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
-#elif 0
+#elif 1
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
-#elif 0
+#elif 1
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
#elif 1
device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw
...
@@ -116,11 +116,7 @@ struct ConstantTensorDescriptor
static_for<0, nDim, 1>{}([&](auto IDim) {
constexpr index_t idim = IDim.Get();
-#if DEVICE_BACKEND_HIP
-id += __mul24(multi_id[idim], GetStride(IDim));
-#else
id += multi_id[idim] * GetStride(IDim);
-#endif
});
return id;
...
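With the `__mul24` special case removed above, `Get1dIndex` is the same multiply-accumulate over strides on every backend. A minimal standalone sketch of that computation (assumed types and example strides, not the real `ConstantTensorDescriptor`):

```cpp
// Sketch only: what the offset computation amounts to after this change,
// i.e. a plain stride dot-product with no HIP-specific __mul24 fast path.
#include <array>
#include <cstdio>

template <std::size_t nDim>
constexpr std::size_t get_1d_index(const std::array<std::size_t, nDim>& multi_id,
                                   const std::array<std::size_t, nDim>& strides)
{
    std::size_t id = 0;
    for(std::size_t idim = 0; idim < nDim; ++idim)
        id += multi_id[idim] * strides[idim]; // full-width multiply on every backend
    return id;
}

int main()
{
    // e.g. a packed [C, H, W, N] descriptor with N = 16, W = H = 34
    constexpr std::array<std::size_t, 4> strides{16 * 34 * 34, 16 * 34, 16, 1};
    std::printf("%zu\n", get_1d_index<4>({1, 2, 3, 4}, strides));
    return 0;
}
```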
@@ -213,7 +213,6 @@ struct Blockwise3dTensorCopy3
#pragma unroll
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
{
-#pragma unroll
const index_t src_offset =
SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
...
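The inner `#pragma unroll` removed above sat in front of a plain declaration rather than a loop. My reading (an assumption, not stated in the commit): the pragma only applies to the loop statement that immediately follows it, and Clang-based HIP compilers reject an unroll pragma that is not followed by a loop, so only the outer pragma directly before the `for` is meaningful:

```cpp
// Sketch only: supported placement of the unroll hint, directly before the loop.
void scale4(float* p)
{
#pragma unroll
    for(int i = 0; i < 4; ++i) // the pragma acts on this loop and nothing else
        p[i] *= 2.0f;
}
```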
@@ -341,7 +341,8 @@ struct BlockwiseChwnTensorCopyPadded
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
const Float* p_src_tmp =
-p_src + src_desc.Get1dIndex(c_block_data_begin,
+p_src +
+src_desc.Get1dIndex(c_block_data_begin,
(ho_block_data_begin + h_block_pad_low) - h_global_pad_low,
(wo_block_data_begin + w_block_pad_low) - w_global_pad_low,
n_block_data_begin);
...
@@ -404,7 +404,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
threadwise_matrix_copy(
c_thread_sub_mtx,
-p_c_thread + c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
+p_c_thread +
+c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
n_repeat * NPerLevel1Cluster),
c_block_mtx,
p_c_block +
...
@@ -93,7 +93,8 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
Float p_out_thread[out_thread_desc.GetElementSpace()];
threadwise_4d_tensor_copy(out_block_desc,
-p_out_block + out_block_desc.Get1dIndex(n_thread_data_begin,
+p_out_block +
+out_block_desc.Get1dIndex(n_thread_data_begin,
k_thread_data_begin,
ho_thread_data_begin,
wo_thread_data_begin),
@@ -107,7 +108,8 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
// threadwise convolution
threadwise_direct_convolution_2(
in_thread_block_desc,
-p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin,
+p_in_block +
+in_block_desc.Get1dIndex(n_thread_data_begin,
c_thread_data_begin,
hi_thread_data_begin,
wi_thread_data_begin),
@@ -122,7 +124,8 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
threadwise_4d_tensor_copy(out_thread_desc,
p_out_thread,
out_block_desc,
-p_out_block + out_block_desc.Get1dIndex(n_thread_data_begin,
+p_out_block +
+out_block_desc.Get1dIndex(n_thread_data_begin,
k_thread_data_begin,
ho_thread_data_begin,
wo_thread_data_begin),
...
@@ -56,7 +56,7 @@ struct BlockwiseNdTensorCopyReorder_v3
"wrong! BlockSize is not big enough for ThreadPerDims!");
// sanity check: work division
-static_for<0, nDim, 1>{}([](auto IDim) {
+static_for<0, nDim, 1>{}([&](auto IDim) {
constexpr auto I = decltype(IDim){};
constexpr index_t src_len = src_lengths.Get(I);
constexpr index_t src_sub_len = src_sub_lengths.Get(I);
@@ -220,7 +220,7 @@ struct BlockwiseNdTensorCopyReorder_v3
constexpr index_t dst_offset = DstDesc{}.Get1dIndex(dst_data_multi_id);
// write in the order of dst
#if 1
threadwise_nd_tensor_copy_reorder_given_dst2src_v2(thread_tensor_desc,
p_clipboard + clipboard_offset,
...
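The only change in `BlockwiseNdTensorCopyReorder_v3` above is the lambda's capture list, `[]` to `[&]`. A plausible reason (my inference, not stated in the commit): the body calls `.Get(I)` on enclosing locals such as `src_lengths`, and a member-function call odr-uses the object, which a capture-less lambda may not do even for `constexpr` locals. A minimal sketch with made-up names:

```cpp
// Sketch only: why the reference capture matters for the member call.
#include <cstddef>

struct Lengths
{
    int v[4];
    constexpr int Get(std::size_t i) const { return v[i]; }
};

int main()
{
    constexpr Lengths src_lengths{{2, 4, 8, 16}};

    // [&] is required here: the member call odr-uses src_lengths,
    // so a capture-less [] lambda would not compile.
    auto f = [&](auto I) { return src_lengths.Get(I); };

    return f(std::size_t{2}) == 8 ? 0 : 1;
}
```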
@@ -43,10 +43,11 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
-static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
-NPerBlock % GemmNPerThreadSubC == 0) ||
+static_assert(
+NPerBlock % NPerThread == 0 &&
+((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
-GemmNPerThreadSubC % NPerThread == 0),
+GemmNPerThreadSubC % NPerThread == 0)),
"wrong!");
constexpr auto I0 = Number<0>{};
@@ -219,7 +220,8 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
const Float* p_in_global_block_offset =
-p_in_global + in_c_h_w_n_global_desc.Get1dIndex(
+p_in_global +
+in_c_h_w_n_global_desc.Get1dIndex(
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
const Float* p_wei_global_block_offset =
@@ -275,20 +277,21 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
-static_if<GemmNPerThreadSubC <= NPerBlock>{}(
-[&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. Using this trick to
+static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
+// perfect forwarding.
+// Using this trick to
// make this lambda a generic lambda, so it won't be compiled until
// instantiated
-static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
-NPerBlock % GemmNPerThreadSubC == 0),
+static_assert(
+(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
"wrong!");
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
-constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
-f_dummy(NPerBlock / GemmNPerThreadSubC);
+constexpr index_t W2 =
+(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
constexpr index_t K2 = GemmMPerThreadSubC;
@@ -322,19 +325,18 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
}
#endif
-threadwise_nd_tensor_copy(out_10d_thread_desc,
+threadwise_nd_tensor_copy(
+out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
-out_k_h_w_n_global_desc.Get1dIndex(
-k_block_data_begin + k_thread_data_begin,
+out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
-})
-.else_([&](auto f_dummy) {
+}).else_([&](auto f_dummy) {
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0,
"wrong!");
@@ -349,17 +351,8 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
-constexpr auto out_10d_global_desc =
-make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
-K1,
-K2,
-Ho,
-Wo / (W1 * W2 * W3),
-W1,
-W2,
-W3,
-N / N1,
-N1>{});
+constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
+Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
@@ -382,12 +375,12 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
}
#endif
-threadwise_nd_tensor_copy(out_10d_thread_desc,
+threadwise_nd_tensor_copy(
+out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
-out_k_h_w_n_global_desc.Get1dIndex(
-k_block_data_begin + k_thread_data_begin,
+out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
...
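The reworked assertion at the top of this kernel (and repeated in the v1r2/v1r3 variants below) adds the parentheses needed to express `NPerBlock % NPerThread == 0 && (A || B)`; without them, `&&` binds tighter than `||` and the condition groups as `(NPerBlock % NPerThread == 0 && A) || B`. In this particular assertion the second alternative forces `NPerThread == NPerBlock`, so (assuming positive tile sizes) it already implies the divisibility term and both groupings accept the same configurations; the rewrite mainly makes the intended grouping explicit and silences parentheses warnings. A minimal sketch of the general precedence point (illustrative booleans, not from the commit):

```cpp
// Sketch only: && binds tighter than ||, so "x && a || b" is "(x && a) || b",
// which differs from the intended "x && (a || b)" whenever b is true but x is false.
int main()
{
    constexpr bool x = false, a = false, b = true;

    static_assert((x && a || b) == ((x && a) || b), "default grouping of && over ||");
    static_assert((x && a || b) != (x && (a || b)), "parenthesized form differs here");

    return 0;
}
```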
@@ -44,10 +44,11 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
-static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
-NPerBlock % GemmNPerThreadSubC == 0) ||
+static_assert(
+NPerBlock % NPerThread == 0 &&
+((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
-GemmNPerThreadSubC % NPerThread == 0),
+GemmNPerThreadSubC % NPerThread == 0)),
"wrong!");
constexpr auto I0 = Number<0>{};
@@ -125,8 +126,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
// blockwise copy
// input: format is [C, Hi, Wi, N]
#if 1
const auto blockwise_in_copy =
Blockwise4dTensorCopy1<BlockSize,
@@ -228,7 +229,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
#if 1
const Float* p_in_global_block_offset =
-p_in_global + in_c_h_w_n_global_desc.Get1dIndex(
+p_in_global +
+in_c_h_w_n_global_desc.Get1dIndex(
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
const Float* p_wei_global_block_offset =
@@ -273,12 +275,12 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, k_block_data_begin);
-for(index_t c_block_data_begin = 0; c_block_data_begin < C;
+for(index_t
+c_block_data_begin = 0;
+c_block_data_begin < C;
c_block_data_begin += CPerBlock,
-p_in_global_block_offset +=
-CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
-p_wei_global_block_offset +=
-CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
+p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
+p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
{
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
@@ -308,20 +310,21 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
-static_if<GemmNPerThreadSubC <= NPerBlock>{}(
-[&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. Using this trick to
+static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
+// perfect forwarding.
+// Using this trick to
// make this lambda a generic lambda, so it won't be compiled until
// instantiated
-static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
-NPerBlock % GemmNPerThreadSubC == 0),
+static_assert(
+(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
"wrong!");
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
-constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
-f_dummy(NPerBlock / GemmNPerThreadSubC);
+constexpr index_t W2 =
+(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
constexpr index_t K2 = GemmMPerThreadSubC;
@@ -355,19 +358,18 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
}
#endif
-threadwise_nd_tensor_copy(out_10d_thread_desc,
+threadwise_nd_tensor_copy(
+out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
-out_k_h_w_n_global_desc.Get1dIndex(
-k_block_data_begin + k_thread_data_begin,
+out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
-})
-.else_([&](auto f_dummy) {
+}).else_([&](auto f_dummy) {
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0,
"wrong!");
@@ -382,17 +384,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
-constexpr auto out_10d_global_desc =
-make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
-K1,
-K2,
-Ho,
-Wo / (W1 * W2 * W3),
-W1,
-W2,
-W3,
-N / N1,
-N1>{});
+constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
+Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
@@ -415,12 +408,12 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
}
#endif
-threadwise_nd_tensor_copy(out_10d_thread_desc,
+threadwise_nd_tensor_copy(
+out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
-out_k_h_w_n_global_desc.Get1dIndex(
-k_block_data_begin + k_thread_data_begin,
+out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
...
@@ -49,8 +49,11 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
{
// be careful of this assertion
static_assert(
-NPerThread <= NPerBlock && NPerBlock % NPerThread == 0,
-"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0");
+NPerBlock % NPerThread == 0 &&
+((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
+(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
+GemmNPerThreadSubC % NPerThread == 0)),
+"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -262,12 +265,12 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, k_block_data_begin);
-for(index_t c_block_data_begin = 0; c_block_data_begin < C;
+for(index_t
+c_block_data_begin = 0;
+c_block_data_begin < C;
c_block_data_begin += CPerBlock,
-p_in_global_block_offset +=
-CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
-p_wei_global_block_offset +=
-CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
+p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
+p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
{
Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
@@ -333,11 +336,12 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
}
#endif
-threadwise_10d_tensor_copy(out_10d_thread_desc,
+threadwise_10d_tensor_copy(
+out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
-p_out_global + out_k_h_w_n_global_desc.Get1dIndex(
-k_block_data_begin + k_thread_data_begin,
+p_out_global +
+out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
...
@@ -43,10 +43,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
-static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
-NPerBlock % GemmNPerThreadSubC == 0) ||
+static_assert(
+NPerBlock % NPerThread == 0 &&
+((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
-GemmNPerThreadSubC % NPerThread == 0),
+GemmNPerThreadSubC % NPerThread == 0)),
"wrong!");
constexpr auto I0 = Number<0>{};
@@ -212,9 +213,10 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
-#if 1
+#if 0
const Float* p_in_global_block_offset =
-p_in_global + in_c_h_w_n_global_desc.Get1dIndex(
+p_in_global +
+in_c_h_w_n_global_desc.Get1dIndex(
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
const Float* p_wei_global_block_offset =
@@ -226,6 +228,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
{
for(index_t y = 0; y < Y; ++y)
{
+#pragma unroll
for(index_t x = 0; x < X; ++x)
{
blockwise_in_copy.Run(p_in_global_block_offset +
@@ -287,20 +290,21 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
-static_if<GemmNPerThreadSubC <= NPerBlock>{}(
-[&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. Using this trick to
+static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
+// perfect forwarding.
+// Using this trick to
// make this lambda a generic lambda, so it won't be compiled until
// instantiated
-static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
-NPerBlock % GemmNPerThreadSubC == 0),
+static_assert(
+(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
"wrong!");
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
-constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
-f_dummy(NPerBlock / GemmNPerThreadSubC);
+constexpr index_t W2 =
+(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
constexpr index_t K2 = GemmMPerThreadSubC;
@@ -334,19 +338,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
}
#endif
-threadwise_nd_tensor_copy(out_10d_thread_desc,
+threadwise_nd_tensor_copy(
+out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
-out_k_h_w_n_global_desc.Get1dIndex(
-k_block_data_begin + k_thread_data_begin,
+out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
-})
-.else_([&](auto f_dummy) {
+}).else_([&](auto f_dummy) {
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0,
"wrong!");
@@ -361,17 +364,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
-constexpr auto out_10d_global_desc =
-make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
-K1,
-K2,
-Ho,
-Wo / (W1 * W2 * W3),
-W1,
-W2,
-W3,
-N / N1,
-N1>{});
+constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
+Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
@@ -394,12 +388,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
}
#endif
-threadwise_nd_tensor_copy(out_10d_thread_desc,
+threadwise_nd_tensor_copy(
+out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
-out_k_h_w_n_global_desc.Get1dIndex(
-k_block_data_begin + k_thread_data_begin,
+out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
...
@@ -43,10 +43,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
-static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
-NPerBlock % GemmNPerThreadSubC == 0) ||
+static_assert(
+NPerBlock % NPerThread == 0 &&
+((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
-GemmNPerThreadSubC % NPerThread == 0),
+GemmNPerThreadSubC % NPerThread == 0)),
"wrong!");
constexpr auto I0 = Number<0>{};
@@ -127,8 +128,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
// blockwise copy
// input: format is [C, Hi, Wi, N]
#if 0
const auto blockwise_in_copy =
Blockwise4dTensorCopy1<BlockSize,
@@ -349,20 +350,21 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
-static_if<GemmNPerThreadSubC <= NPerBlock>{}(
-[&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. Using this trick to
+static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
+// perfect forwarding.
+// Using this trick to
// make this lambda a generic lambda, so it won't be compiled until
// instantiated
-static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
-NPerBlock % GemmNPerThreadSubC == 0),
+static_assert(
+(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
"wrong!");
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
-constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
-f_dummy(NPerBlock / GemmNPerThreadSubC);
+constexpr index_t W2 =
+(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
constexpr index_t K2 = GemmMPerThreadSubC;
@@ -396,19 +398,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
}
#endif
-threadwise_nd_tensor_copy(out_10d_thread_desc,
+threadwise_nd_tensor_copy(
+out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
-out_k_h_w_n_global_desc.Get1dIndex(
-k_block_data_begin + k_thread_data_begin,
+out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
-})
-.else_([&](auto f_dummy) {
+}).else_([&](auto f_dummy) {
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0,
"wrong!");
@@ -423,17 +424,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
-constexpr auto out_10d_global_desc =
-make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
-K1,
-K2,
-Ho,
-Wo / (W1 * W2 * W3),
-W1,
-W2,
-W3,
-N / N1,
-N1>{});
+constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
+Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
@@ -456,12 +448,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
}
#endif
-threadwise_nd_tensor_copy(out_10d_thread_desc,
+threadwise_nd_tensor_copy(
+out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
-out_k_h_w_n_global_desc.Get1dIndex(
-k_block_data_begin + k_thread_data_begin,
+out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
...
@@ -47,10 +47,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
-static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
-NPerBlock % GemmNPerThreadSubC == 0) ||
+static_assert(
+NPerBlock % NPerThread == 0 &&
+((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
-GemmNPerThreadSubC % NPerThread == 0),
+GemmNPerThreadSubC % NPerThread == 0)),
"wrong!");
constexpr auto I0 = Number<0>{};
@@ -349,20 +350,21 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
-static_if<GemmNPerThreadSubC <= NPerBlock>{}(
-[&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. Using this trick to
+static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
+// perfect forwarding.
+// Using this trick to
// make this lambda a generic lambda, so it won't be compiled until
// instantiated
-static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
-NPerBlock % GemmNPerThreadSubC == 0),
+static_assert(
+(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
"wrong!");
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
-constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
-f_dummy(NPerBlock / GemmNPerThreadSubC);
+constexpr index_t W2 =
+(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
constexpr index_t K2 = GemmMPerThreadSubC;
@@ -396,19 +398,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
}
#endif
-threadwise_nd_tensor_copy(out_10d_thread_desc,
+threadwise_nd_tensor_copy(
+out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
-out_k_h_w_n_global_desc.Get1dIndex(
-k_block_data_begin + k_thread_data_begin,
+out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
-})
-.else_([&](auto f_dummy) {
+}).else_([&](auto f_dummy) {
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0,
"wrong!");
@@ -423,17 +424,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
-constexpr auto out_10d_global_desc =
-make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
-K1,
-K2,
-Ho,
-Wo / (W1 * W2 * W3),
-W1,
-W2,
-W3,
-N / N1,
-N1>{});
+constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
+Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
@@ -456,12 +448,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
}
#endif
-threadwise_nd_tensor_copy(out_10d_thread_desc,
+threadwise_nd_tensor_copy(
+out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
-out_k_h_w_n_global_desc.Get1dIndex(
-k_block_data_begin + k_thread_data_begin,
+out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
...
@@ -47,10 +47,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
-static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
-NPerBlock % GemmNPerThreadSubC == 0) ||
+static_assert(
+NPerBlock % NPerThread == 0 &&
+((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
-GemmNPerThreadSubC % NPerThread == 0),
+GemmNPerThreadSubC % NPerThread == 0)),
"wrong!");
constexpr auto I0 = Number<0>{};
@@ -223,7 +224,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
#if 1
const Float* p_in_global_block_offset =
-p_in_global + in_n_c_h_w_global_desc.Get1dIndex(
+p_in_global +
+in_n_c_h_w_global_desc.Get1dIndex(
n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);
const Float* p_wei_global_block_offset =
@@ -329,20 +331,21 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
-static_if<GemmNPerThreadSubC <= NPerBlock>{}(
-[&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. Using this trick to
+static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
+// perfect forwarding.
+// Using this trick to
// make this lambda a generic lambda, so it won't be compiled until
// instantiated
-static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
-NPerBlock % GemmNPerThreadSubC == 0),
+static_assert(
+(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
"wrong!");
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
-constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
-f_dummy(NPerBlock / GemmNPerThreadSubC);
+constexpr index_t W2 =
+(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
constexpr index_t K2 = GemmMPerThreadSubC;
@@ -376,19 +379,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
}
#endif
-threadwise_nd_tensor_copy(out_10d_thread_desc,
+threadwise_nd_tensor_copy(
+out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
-out_k_h_w_n_global_desc.Get1dIndex(
-k_block_data_begin + k_thread_data_begin,
+out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
-})
-.else_([&](auto f_dummy) {
+}).else_([&](auto f_dummy) {
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0,
"wrong!");
@@ -403,17 +405,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
-constexpr auto out_10d_global_desc =
-make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
-K1,
-K2,
-Ho,
-Wo / (W1 * W2 * W3),
-W1,
-W2,
-W3,
-N / N1,
-N1>{});
+constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
+Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
@@ -436,12 +429,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
}
#endif
-threadwise_nd_tensor_copy(out_10d_thread_desc,
+threadwise_nd_tensor_copy(
+out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
-out_k_h_w_n_global_desc.Get1dIndex(
-k_block_data_begin + k_thread_data_begin,
+out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
...
@@ -47,10 +47,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
-static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
-NPerBlock % GemmNPerThreadSubC == 0) ||
+static_assert(
+NPerBlock % NPerThread == 0 &&
+((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
-GemmNPerThreadSubC % NPerThread == 0),
+GemmNPerThreadSubC % NPerThread == 0)),
"wrong!");
constexpr auto I0 = Number<0>{};
@@ -223,7 +224,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
#if 1
const Float* p_in_global_block_offset =
-p_in_global + in_n_c_h_w_global_desc.Get1dIndex(
+p_in_global +
+in_n_c_h_w_global_desc.Get1dIndex(
n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);
const Float* p_wei_global_block_offset =
@@ -409,14 +411,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
-p_out_global + out_n_k_h_w_global_desc.Get1dIndex(
+p_out_global +
+out_n_k_h_w_global_desc.Get1dIndex(
n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
out_10d_thread_desc.GetLengths(),
map_out_global2thread);
// Number<OutThreadCopyDataPerWrite_W>{});
#endif
})
.else_([&](auto f_dummy) {
@@ -500,14 +503,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
-p_out_global + out_n_k_h_w_global_desc.Get1dIndex(
+p_out_global +
+out_n_k_h_w_global_desc.Get1dIndex(
n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
out_10d_thread_desc.GetLengths(),
map_out_global2thread);
// Number<OutThreadCopyDataPerWrite_W>{});
#endif
});
}
...
@@ -365,11 +365,12 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
constexpr auto out_kb_global_desc = make_ConstantTensorDescriptor(Sequence<K, B>{});
-threadwise_6d_tensor_copy(out_6d_thread_desc,
+threadwise_6d_tensor_copy(
+out_6d_thread_desc,
p_out_thread,
out_6d_global_desc,
-p_out_global + out_kb_global_desc.Get1dIndex(
-k_thread_data_begin, b_thread_data_begin),
+p_out_global +
+out_kb_global_desc.Get1dIndex(k_thread_data_begin, b_thread_data_begin),
out_6d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite>{});
}
...
@@ -113,7 +113,8 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
c_block_work_begin += CPerBlock)
{
// copy input tensor to LDS
-blockwise_in_copy.Run(p_in_global + in_global_desc.Get1dIndex(n_block_work_begin,
+blockwise_in_copy.Run(p_in_global +
+in_global_desc.Get1dIndex(n_block_work_begin,
c_block_work_begin,
hi_block_work_begin,
wi_block_work_begin),
@@ -143,9 +144,9 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
}
// copy output tensor from LDS to device mem
-blockwise_out_copy.Run(p_out_block,
-p_out_global + out_global_desc.Get1dIndex(n_block_work_begin,
-k_block_work_begin,
-ho_block_work_begin,
-wo_block_work_begin));
+blockwise_out_copy.Run(
+p_out_block,
+p_out_global +
+out_global_desc.Get1dIndex(
+n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin));
}