Commit acd7082f authored by Chao Liu
Browse files

adding ConstantMergedTensorDescriptor, refactoring ConstantTensorDescriptor, Sequence

parent cd29b09a
...@@ -221,11 +221,11 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn ...@@ -221,11 +221,11 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
#if 0 #if 0
const Float* p_in_global_block_offset = const Float* p_in_global_block_offset =
p_in_global + in_n_c_h_w_global_desc.Get1dIndex( p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin); n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);
const Float* p_wei_global_block_offset = const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1), p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
...@@ -234,20 +234,20 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn ...@@ -234,20 +234,20 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
for(index_t y = 0; y < Y; ++y) for(index_t y = 0; y < Y; ++y)
{ {
blockwise_in_copy_reorder.Run(p_in_global_block_offset + blockwise_in_copy_reorder.Run(p_in_global_block_offset +
in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, 0), in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, 0),
p_in_block); p_in_block);
blockwise_wei_copy.Run(p_wei_global_block_offset + blockwise_wei_copy.Run(p_wei_global_block_offset +
wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, 0), wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, 0, 0),
p_wei_block); p_wei_block);
__syncthreads(); __syncthreads();
for(index_t x = 0; x < X; ++x) for(index_t x = 0; x < X; ++x)
{ {
blockwise_batch_gemm.Run(p_wei_block + wei_c_x_k_block_desc.Get1dIndex(0, x, 0), blockwise_batch_gemm.Run(p_wei_block + wei_c_x_k_block_desc.GetOffsetFromMultiIndex(0, x, 0),
p_in_block + p_in_block +
in_c_h_w_n_block_desc.Get1dIndex(0, 0, x, 0), in_c_h_w_n_block_desc.GetOffsetFromMultiIndex(0, 0, x, 0),
p_out_thread); p_out_thread);
} }
...@@ -259,11 +259,12 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn ...@@ -259,11 +259,12 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
{ {
const Float* p_in_global_block_offset = const Float* p_in_global_block_offset =
p_in_global + p_in_global +
in_n_c_h_w_global_desc.Get1dIndex( in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin); n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin);
const Float* p_wei_global_block_offset = const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, k_block_data_begin); p_wei_global +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, 0, k_block_data_begin);
for(index_t for(index_t
c_block_data_begin = 0; c_block_data_begin = 0;
...@@ -287,10 +288,10 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn ...@@ -287,10 +288,10 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
for(index_t x = 0; x < X; ++x) for(index_t x = 0; x < X; ++x)
{ {
blockwise_batch_gemm.Run(p_wei_block + wei_c_x_k_block_desc.Get1dIndex(0, x, 0), blockwise_batch_gemm.Run(
p_in_block + p_wei_block + wei_c_x_k_block_desc.GetOffsetFromMultiIndex(0, x, 0),
in_c_h_w_n_block_desc.Get1dIndex(0, 0, x, 0), p_in_block + in_c_h_w_n_block_desc.GetOffsetFromMultiIndex(0, 0, x, 0),
p_out_thread); p_out_thread);
} }
__syncthreads(); __syncthreads();
...@@ -336,16 +337,16 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn ...@@ -336,16 +337,16 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
} }
#endif #endif
threadwise_10d_tensor_copy( threadwise_10d_tensor_copy(out_10d_thread_desc,
out_10d_thread_desc, p_out_thread,
p_out_thread, out_10d_global_desc,
out_10d_global_desc, p_out_global +
p_out_global + out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin), n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(), out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{}); Number<OutThreadCopyDataPerWrite_N>{});
} }
}; };
...@@ -82,7 +82,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn ...@@ -82,7 +82,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
constexpr auto block_work_desc = make_ConstantTensorDescriptor( constexpr auto block_work_desc = make_ConstantTensorDescriptor(
Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{}); Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id()); const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock; const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock; const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
...@@ -225,11 +226,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn ...@@ -225,11 +226,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
#if 1 #if 1
const Float* p_in_global_block_offset = const Float* p_in_global_block_offset =
p_in_global + p_in_global +
in_c_h_w_n_global_desc.Get1dIndex( in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin); 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
const Float* p_wei_global_block_offset = const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); p_wei_global +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0), p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
...@@ -240,13 +242,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn ...@@ -240,13 +242,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
#pragma unroll #pragma unroll
for(index_t x = 0; x < X; ++x) for(index_t x = 0; x < X; ++x)
{ {
blockwise_in_copy.Run(p_in_global_block_offset + blockwise_in_copy.Run(
in_c_h_w_n_global_desc.Get1dIndex(0, y, x, 0), p_in_global_block_offset +
p_in_block); in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_in_block);
blockwise_wei_copy.Run(p_wei_global_block_offset + blockwise_wei_copy.Run(
wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, 0), p_wei_global_block_offset +
p_wei_block); wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_wei_block);
__syncthreads(); __syncthreads();
...@@ -263,11 +267,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn ...@@ -263,11 +267,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
{ {
const Float* p_in_global_block_offset = const Float* p_in_global_block_offset =
p_in_global + p_in_global +
in_c_h_w_n_global_desc.Get1dIndex( in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
0, hi_block_data_begin + y, wi_block_data_begin + x, n_block_data_begin); 0, hi_block_data_begin + y, wi_block_data_begin + x, n_block_data_begin);
const Float* p_wei_global_block_offset = const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin); p_wei_global +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C; for(index_t c_block_data_begin = 0; c_block_data_begin < C;
c_block_data_begin += CPerBlock, c_block_data_begin += CPerBlock,
...@@ -347,17 +352,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn ...@@ -347,17 +352,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
} }
#endif #endif
threadwise_tensor_slice_copy( threadwise_tensor_slice_copy(out_10d_thread_desc,
out_10d_thread_desc, p_out_thread,
p_out_thread, out_10d_global_desc,
out_10d_global_desc, p_out_global +
p_out_global + out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin), n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(), out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{}); Number<OutThreadCopyDataPerWrite_N>{});
}).else_([&](auto f_dummy) { }).else_([&](auto f_dummy) {
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0, GemmNPerThreadSubC % NPerThread == 0,
...@@ -397,17 +402,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn ...@@ -397,17 +402,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
} }
#endif #endif
threadwise_tensor_slice_copy( threadwise_tensor_slice_copy(out_10d_thread_desc,
out_10d_thread_desc, p_out_thread,
p_out_thread, out_10d_global_desc,
out_10d_global_desc, p_out_global +
p_out_global + out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin), n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(), out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{}); Number<OutThreadCopyDataPerWrite_N>{});
}); });
} }
}; };
...@@ -85,10 +85,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn ...@@ -85,10 +85,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock); constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock); constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
constexpr auto block_work_desc = make_ConstantTensorDescriptor( constexpr auto block_work_desc = make_packed_ConstantTensorDescriptor(
Sequence<KBlockWork, HBlockWork, WBlockWork, NBlockWork>{}); Sequence<KBlockWork, HBlockWork, WBlockWork, NBlockWork>{});
const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id()); const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_begin = block_work_multi_id[0] * KPerBlock; const index_t k_block_data_begin = block_work_multi_id[0] * KPerBlock;
const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock; const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock;
...@@ -108,7 +109,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn ...@@ -108,7 +109,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
GemmDataPerReadA, GemmDataPerReadA,
GemmDataPerReadB); GemmDataPerReadB);
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto in_c_h_w_n_block_desc = make_ranked_ConstantTensorDescriptor_with_alignment(
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{}, Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
Number<InBlockCopyDataPerRead_N>{}); Number<InBlockCopyDataPerRead_N>{});
...@@ -117,12 +118,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn ...@@ -117,12 +118,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0, static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not meet"); "GemmDataPerReadB alignment requirement is not meet");
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto wei_c_k_block_desc = make_ranked_ConstantTensorDescriptor_with_alignment(
Sequence<CPerBlock, KPerBlock>{}, Sequence<CPerBlock, KPerBlock>{},
Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{}); Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
// tensor view of threadwise output in register // tensor view of threadwise output in register
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor( constexpr auto out_k_h_w_n_thread_desc = make_packed_ConstantTensorDescriptor(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{}); Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
// blockwise copy // blockwise copy
...@@ -243,11 +244,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn ...@@ -243,11 +244,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
{ {
const Float* p_in_global_block_offset = const Float* p_in_global_block_offset =
p_in_global + p_in_global +
in_c_h_w_n_global_desc.Get1dIndex( in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
0, hi_block_data_begin + y, wi_block_data_begin + x, n_block_data_begin); 0, hi_block_data_begin + y, wi_block_data_begin + x, n_block_data_begin);
const Float* p_wei_global_block_offset = const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin); p_wei_global +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
...@@ -399,17 +401,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn ...@@ -399,17 +401,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
} }
#endif #endif
threadwise_tensor_slice_copy( threadwise_tensor_slice_copy(out_10d_thread_desc,
out_10d_thread_desc, p_out_thread,
p_out_thread, out_10d_global_desc,
out_10d_global_desc, p_out_global +
p_out_global + out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin), n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(), out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{}); Number<OutThreadCopyDataPerWrite_N>{});
}).else_([&](auto fwd) { }).else_([&](auto fwd) {
static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0, GemmNPerThreadSubC % NPerThread == 0,
...@@ -450,17 +452,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn ...@@ -450,17 +452,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
} }
#endif #endif
threadwise_tensor_slice_copy( threadwise_tensor_slice_copy(out_10d_thread_desc,
out_10d_thread_desc, p_out_thread,
p_out_thread, out_10d_global_desc,
out_10d_global_desc, p_out_global +
p_out_global + out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin), n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(), out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{}); Number<OutThreadCopyDataPerWrite_N>{});
}); });
} }
}; };
...@@ -86,10 +86,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn ...@@ -86,10 +86,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock); constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock); constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
constexpr auto block_work_desc = make_ConstantTensorDescriptor( constexpr auto block_work_desc = make_packed_ConstantTensorDescriptor(
Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{}); Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id()); const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock; const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock; const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
...@@ -101,7 +102,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn ...@@ -101,7 +102,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
// global tensor view // global tensor view
constexpr auto wei_c_k_global_desc = constexpr auto wei_c_k_global_desc =
make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{}); make_ranked_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});
// LDS tensor view // LDS tensor view
// be careful of alignment // be careful of alignment
...@@ -110,7 +111,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn ...@@ -110,7 +111,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
GemmDataPerReadA, GemmDataPerReadA,
GemmDataPerReadB); GemmDataPerReadB);
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto in_c_h_w_n_block_desc = make_ranked_ConstantTensorDescriptor_with_alignment(
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{}, Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
Number<InBlockReorderDataPerWrite_N>{}); Number<InBlockReorderDataPerWrite_N>{});
...@@ -119,12 +120,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn ...@@ -119,12 +120,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0, static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not meet"); "GemmDataPerReadB alignment requirement is not meet");
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto wei_c_k_block_desc = make_ranked_ConstantTensorDescriptor_with_alignment(
Sequence<CPerBlock, KPerBlock>{}, Sequence<CPerBlock, KPerBlock>{},
Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{}); Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
// tensor view of threadwise output in register // tensor view of threadwise output in register
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor( constexpr auto out_k_h_w_n_thread_desc = make_packed_ConstantTensorDescriptor(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{}); Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
// blockwise copy // blockwise copy
...@@ -241,11 +242,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn ...@@ -241,11 +242,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
{ {
const Float* p_in_global_block_offset = const Float* p_in_global_block_offset =
p_in_global + p_in_global +
in_n_c_h_w_global_desc.Get1dIndex( in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x); n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x);
const Float* p_wei_global_block_offset = const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin); p_wei_global +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
...@@ -359,13 +361,13 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn ...@@ -359,13 +361,13 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock; const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) { // fwd do nothing but
// perfect forwarding. // perfect forwarding.
// Using this trick to // Using this trick to
// make this lambda a generic lambda, so it won't be compiled until // make this lambda a generic lambda, so it won't be compiled until
// instantiated // instantiated
static_assert( static_assert(
(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0), (fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
"wrong!"); "wrong!");
// output is a 10d tensor // output is a 10d tensor
...@@ -373,12 +375,13 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn ...@@ -373,12 +375,13 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
constexpr index_t N1 = NPerBlock / N2; constexpr index_t N1 = NPerBlock / N2;
constexpr index_t W2 = constexpr index_t W2 =
(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC); (GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2; constexpr index_t W1 = WoPerBlock / W2;
constexpr index_t K2 = GemmMPerThreadSubC; constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread; constexpr index_t K1 = KPerBlock / KPerThread;
#if 0
constexpr auto out_10d_global_desc = constexpr auto out_10d_global_desc =
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2), make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
K1, K1,
...@@ -387,12 +390,23 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn ...@@ -387,12 +390,23 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
Wo / (W1 * W2), Wo / (W1 * W2),
W1, W1,
W2, W2,
N / f_dummy(N1 * N2), N / fwd(N1 * N2),
N1, N1,
N2>{}); N2>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{}); Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
#else
constexpr auto out_10d_global_desc = fwd(out_k_h_w_n_global_desc)
.Fold(I3, Number<N1>{}, Number<N2>{})
.Fold(I2, Number<W1>{}, Number<W2>{})
.Fold(I0, Number<K1>{}, Number<K2>{});
constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc)
.Fold(I3, Number<1>{}, Number<N2>{})
.Fold(I2, Number<W1>{}, Number<1>{})
.Fold(I0, Number<1>{}, Number<K2>{});
#endif
#if 0 #if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
...@@ -407,19 +421,19 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn ...@@ -407,19 +421,19 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
} }
#endif #endif
threadwise_tensor_slice_copy( threadwise_tensor_slice_copy(out_10d_thread_desc,
out_10d_thread_desc, p_out_thread,
p_out_thread, out_10d_global_desc,
out_10d_global_desc, p_out_global +
p_out_global + out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin), n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(), out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{}); Number<OutThreadCopyDataPerWrite_N>{});
}).else_([&](auto f_dummy) { }).else_([&](auto fwd) {
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0, GemmNPerThreadSubC % NPerThread == 0,
"wrong!"); "wrong!");
...@@ -428,16 +442,30 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn ...@@ -428,16 +442,30 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock; constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster; constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3); constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);
constexpr index_t K2 = GemmMPerThreadSubC; constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread; constexpr index_t K1 = KPerBlock / KPerThread;
constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor( #if 0
constexpr auto out_10d_global_desc = make_packed_ConstantTensorDescriptor(
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{}); Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( constexpr auto out_10d_thread_desc = make_packed_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{}); Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
#else
constexpr auto out_10d_global_desc =
fwd(out_k_h_w_n_global_desc)
.Fold(I3, Number<N1>{})
.Fold(I2, Number<W1>{}, Number<W2>{}, Number<W3>{})
.Fold(I0, Number<K1>{}, Number<K2>{});
constexpr auto out_10d_thread_desc =
fwd(out_k_h_w_n_thread_desc)
.Fold(I3, Number<N1>{})
.Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
.Fold(I0, Number<1>{}, Number<K2>{});
#endif
#if 0 #if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
...@@ -457,17 +485,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn ...@@ -457,17 +485,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
} }
#endif #endif
threadwise_tensor_slice_copy( threadwise_tensor_slice_copy(out_10d_thread_desc,
out_10d_thread_desc, p_out_thread,
p_out_thread, out_10d_global_desc,
out_10d_global_desc, p_out_global +
p_out_global + out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin), n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(), out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{}); Number<OutThreadCopyDataPerWrite_N>{});
}); });
} }
}; };
...@@ -86,10 +86,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw ...@@ -86,10 +86,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock); constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock); constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
constexpr auto block_work_desc = make_ConstantTensorDescriptor( constexpr auto block_work_desc = make_packed_ConstantTensorDescriptor(
Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{}); Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id()); const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock; const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock; const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
...@@ -109,7 +110,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw ...@@ -109,7 +110,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
GemmDataPerReadA, GemmDataPerReadA,
GemmDataPerReadB); GemmDataPerReadB);
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto in_c_h_w_n_block_desc = make_ranked_ConstantTensorDescriptor_with_alignment(
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{}, Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
Number<InBlockReorderDataPerWrite_N>{}); Number<InBlockReorderDataPerWrite_N>{});
...@@ -118,12 +119,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw ...@@ -118,12 +119,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0, static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not meet"); "GemmDataPerReadB alignment requirement is not meet");
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto wei_c_k_block_desc = make_ranked_ConstantTensorDescriptor_with_alignment(
Sequence<CPerBlock, KPerBlock>{}, Sequence<CPerBlock, KPerBlock>{},
Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{}); Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
// tensor view of threadwise output in register // tensor view of threadwise output in register
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor( constexpr auto out_k_h_w_n_thread_desc = make_packed_ConstantTensorDescriptor(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{}); Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
// blockwise copy // blockwise copy
...@@ -240,11 +241,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw ...@@ -240,11 +241,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
{ {
const Float* p_in_global_block_offset = const Float* p_in_global_block_offset =
p_in_global + p_in_global +
in_n_c_h_w_global_desc.Get1dIndex( in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x); n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x);
const Float* p_wei_global_block_offset = const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin); p_wei_global +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
...@@ -407,10 +409,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw ...@@ -407,10 +409,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
p_out_thread, p_out_thread,
out_10d_global_desc, out_10d_global_desc,
p_out_global + p_out_global +
out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin, n_block_data_begin + n_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, k_block_data_begin + k_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin), ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
out_10d_thread_desc.GetLengths(), out_10d_thread_desc.GetLengths(),
map_out_global2thread); map_out_global2thread);
// Number<OutThreadCopyDataPerWrite_W>{}); // Number<OutThreadCopyDataPerWrite_W>{});
...@@ -461,10 +464,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw ...@@ -461,10 +464,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
p_out_thread, p_out_thread,
out_10d_global_desc, out_10d_global_desc,
p_out_global + p_out_global +
out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin, n_block_data_begin + n_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, k_block_data_begin + k_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin), ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
out_10d_thread_desc.GetLengths(), out_10d_thread_desc.GetLengths(),
map_out_global2thread); map_out_global2thread);
// Number<OutThreadCopyDataPerWrite_W>{}); // Number<OutThreadCopyDataPerWrite_W>{});
......
...@@ -236,11 +236,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn ...@@ -236,11 +236,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
#if 1 #if 1
const Float* p_in_global_block_offset = const Float* p_in_global_block_offset =
p_in_global + p_in_global +
in_n_c_h_w_global_desc.Get1dIndex( in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin); n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);
const Float* p_wei_global_block_offset = const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); p_wei_global +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1), p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
...@@ -251,23 +252,27 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn ...@@ -251,23 +252,27 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
for(index_t x = 0; x < X; ++x) for(index_t x = 0; x < X; ++x)
{ {
#if 1 #if 1
blockwise_in_copy_reorder.Run(p_in_global_block_offset + blockwise_in_copy_reorder.Run(
in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, x), p_in_global_block_offset +
p_in_block); in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x),
p_in_block);
blockwise_wei_copy.Run(p_wei_global_block_offset +
wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, 0), blockwise_wei_copy.Run(
p_wei_block); p_wei_global_block_offset +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_wei_block);
#else #else
Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()]; Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy_reorder.RunLoadRegisterClipboard( blockwise_in_copy_reorder.RunLoadRegisterClipboard(
p_in_global_block_offset + in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, x), p_in_global_block_offset +
in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x),
p_in_clipboard); p_in_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard( blockwise_wei_copy.RunLoadRegisterClipboard(
p_wei_global_block_offset + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, 0), p_wei_global_block_offset +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_wei_clipboard); p_wei_clipboard);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block); blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block);
...@@ -291,11 +296,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn ...@@ -291,11 +296,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
{ {
const Float* p_in_global_block_offset = const Float* p_in_global_block_offset =
p_in_global + p_in_global +
in_n_c_h_w_global_desc.Get1dIndex( in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x); n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x);
const Float* p_wei_global_block_offset = const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin); p_wei_global +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C; for(index_t c_block_data_begin = 0; c_block_data_begin < C;
c_block_data_begin += CPerBlock, c_block_data_begin += CPerBlock,
...@@ -390,17 +396,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn ...@@ -390,17 +396,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
} }
#endif #endif
threadwise_tensor_slice_copy( threadwise_tensor_slice_copy(out_10d_thread_desc,
out_10d_thread_desc, p_out_thread,
p_out_thread, out_10d_global_desc,
out_10d_global_desc, p_out_global +
p_out_global + out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin), n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(), out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{}); Number<OutThreadCopyDataPerWrite_N>{});
}).else_([&](auto f_dummy) { }).else_([&](auto f_dummy) {
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0, GemmNPerThreadSubC % NPerThread == 0,
...@@ -440,17 +446,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn ...@@ -440,17 +446,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
} }
#endif #endif
threadwise_tensor_slice_copy( threadwise_tensor_slice_copy(out_10d_thread_desc,
out_10d_thread_desc, p_out_thread,
p_out_thread, out_10d_global_desc,
out_10d_global_desc, p_out_global +
p_out_global + out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin), n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(), out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{}); Number<OutThreadCopyDataPerWrite_N>{});
}); });
} }
}; };
...@@ -86,7 +86,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw ...@@ -86,7 +86,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
constexpr auto block_work_desc = make_ConstantTensorDescriptor( constexpr auto block_work_desc = make_ConstantTensorDescriptor(
Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{}); Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id()); const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock; const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock; const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
...@@ -234,11 +235,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw ...@@ -234,11 +235,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
#if 0 #if 0
const Float* p_in_global_block_offset = const Float* p_in_global_block_offset =
p_in_global + p_in_global +
in_n_c_h_w_global_desc.Get1dIndex( in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin); n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);
const Float* p_wei_global_block_offset = const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1), p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
...@@ -250,22 +251,22 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw ...@@ -250,22 +251,22 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
{ {
#if 1 #if 1
blockwise_in_copy_reorder.Run(p_in_global_block_offset + blockwise_in_copy_reorder.Run(p_in_global_block_offset +
in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, x), in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x),
p_in_block); p_in_block);
blockwise_wei_copy.Run(p_wei_global_block_offset + blockwise_wei_copy.Run(p_wei_global_block_offset +
wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, 0), wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_wei_block); p_wei_block);
#else #else
Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()]; Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy_reorder.RunLoadRegisterClipboard( blockwise_in_copy_reorder.RunLoadRegisterClipboard(
p_in_global_block_offset + in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, x), p_in_global_block_offset + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x),
p_in_clipboard); p_in_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard( blockwise_wei_copy.RunLoadRegisterClipboard(
p_wei_global_block_offset + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, 0), p_wei_global_block_offset + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_wei_clipboard); p_wei_clipboard);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block); blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block);
...@@ -289,11 +290,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw ...@@ -289,11 +290,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
{ {
const Float* p_in_global_block_offset = const Float* p_in_global_block_offset =
p_in_global + p_in_global +
in_n_c_h_w_global_desc.Get1dIndex( in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x); n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x);
const Float* p_wei_global_block_offset = const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin); p_wei_global +
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C; for(index_t c_block_data_begin = 0; c_block_data_begin < C;
c_block_data_begin += CPerBlock, c_block_data_begin += CPerBlock,
...@@ -395,10 +397,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw ...@@ -395,10 +397,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
p_out_thread, p_out_thread,
out_10d_global_desc, out_10d_global_desc,
p_out_global + p_out_global +
out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin, n_block_data_begin + n_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, k_block_data_begin + k_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin), ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
out_10d_thread_desc.GetLengths(), out_10d_thread_desc.GetLengths(),
map_out_global2thread); map_out_global2thread);
// Number<OutThreadCopyDataPerWrite_W>{}); // Number<OutThreadCopyDataPerWrite_W>{});
...@@ -444,10 +447,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw ...@@ -444,10 +447,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
p_out_thread, p_out_thread,
out_10d_global_desc, out_10d_global_desc,
p_out_global + p_out_global +
out_n_k_h_w_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
k_block_data_begin + k_thread_data_begin, n_block_data_begin + n_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, k_block_data_begin + k_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin), ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin),
out_10d_thread_desc.GetLengths(), out_10d_thread_desc.GetLengths(),
map_out_global2thread); map_out_global2thread);
// Number<OutThreadCopyDataPerWrite_W>{}); // Number<OutThreadCopyDataPerWrite_W>{});
......
...@@ -193,10 +193,11 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn ...@@ -193,10 +193,11 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
__shared__ Float p_wei_block[wei_block_space]; __shared__ Float p_wei_block[wei_block_space];
const Float* p_in_global_block_offset = const Float* p_in_global_block_offset =
p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin); p_in_global + in_cb_global_desc.GetOffsetFromMultiIndex(0, b_block_data_begin);
const Float* p_wei_global_block_offset = const Float* p_wei_global_block_offset =
p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); p_wei_global +
wei_cyxk_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
// register // register
Float p_out_thread[out_kb_thread_desc.GetElementSpace()]; Float p_out_thread[out_kb_thread_desc.GetElementSpace()];
...@@ -236,7 +237,7 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn ...@@ -236,7 +237,7 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
#elif 1 #elif 1
blockwise_gemm.Run_asm blockwise_gemm.Run_asm
#endif #endif
(p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), (p_wei_block + wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_in_block + y * Wi + x, p_in_block + y * Wi + x,
p_out_thread); p_out_thread);
} }
...@@ -267,8 +268,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn ...@@ -267,8 +268,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
if(n_data < N && h_data < Ho && w_data < Wo) if(n_data < N && h_data < Ho && w_data < Wo)
{ {
p_out_global[out_khwn_global_desc.Get1dIndex(k_data, h_data, w_data, n_data)] = p_out_global[out_khwn_global_desc.GetOffsetFromMultiIndex(
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; k_data, h_data, w_data, n_data)] =
p_out_thread[out_kb_thread_desc.GetOffsetFromMultiIndex(k, b)];
} }
} }
} }
......
...@@ -198,10 +198,11 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -198,10 +198,11 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
__shared__ Float p_wei_block_double[2 * wei_block_space]; __shared__ Float p_wei_block_double[2 * wei_block_space];
const Float* p_in_global_block_offset = const Float* p_in_global_block_offset =
p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin); p_in_global + in_cb_global_desc.GetOffsetFromMultiIndex(0, b_block_data_begin);
const Float* p_wei_global_block_offset = const Float* p_wei_global_block_offset =
p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); p_wei_global +
wei_cyxk_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
// preload data into LDS // preload data into LDS
{ {
...@@ -269,7 +270,8 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -269,7 +270,8 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
#elif 0 #elif 0
blockwise_gemm.Run_asm blockwise_gemm.Run_asm
#endif #endif
(p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), (p_wei_block_now +
wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_in_block_now + y * Wi + x, p_in_block_now + y * Wi + x,
p_out_thread); p_out_thread);
} }
...@@ -310,7 +312,8 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -310,7 +312,8 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
#elif 0 #elif 0
blockwise_gemm.Run_asm blockwise_gemm.Run_asm
#endif #endif
(p_wei_block_double + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), (p_wei_block_double +
wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_in_block_double + y * Wi + x, p_in_block_double + y * Wi + x,
p_out_thread); p_out_thread);
} }
...@@ -336,7 +339,7 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -336,7 +339,7 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
blockwise_gemm.Run_asm blockwise_gemm.Run_asm
#endif #endif
(p_wei_block_double + wei_block_space + (p_wei_block_double + wei_block_space +
wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_in_block_double + in_block_space + y * Wi + x, p_in_block_double + in_block_space + y * Wi + x,
p_out_thread); p_out_thread);
} }
...@@ -365,14 +368,14 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -365,14 +368,14 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
constexpr auto out_kb_global_desc = make_ConstantTensorDescriptor(Sequence<K, B>{}); constexpr auto out_kb_global_desc = make_ConstantTensorDescriptor(Sequence<K, B>{});
threadwise_6d_tensor_copy( threadwise_6d_tensor_copy(out_6d_thread_desc,
out_6d_thread_desc, p_out_thread,
p_out_thread, out_6d_global_desc,
out_6d_global_desc, p_out_global +
p_out_global + out_kb_global_desc.GetOffsetFromMultiIndex(
out_kb_global_desc.Get1dIndex(k_thread_data_begin, b_thread_data_begin), k_thread_data_begin, b_thread_data_begin),
out_6d_thread_desc.GetLengths(), out_6d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite>{}); Number<OutThreadCopyDataPerWrite>{});
} }
else else
{ {
...@@ -393,9 +396,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer ...@@ -393,9 +396,9 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
if(n_data < N && h_data < Ho && w_data < Wo) if(n_data < N && h_data < Ho && w_data < Wo)
{ {
p_out_global[out_khwn_global_desc.Get1dIndex( p_out_global[out_khwn_global_desc.GetOffsetFromMultiIndex(
k_data, h_data, w_data, n_data)] = k_data, h_data, w_data, n_data)] =
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]; p_out_thread[out_kb_thread_desc.GetOffsetFromMultiIndex(k, b)];
} }
} }
} }
......
...@@ -83,7 +83,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw ...@@ -83,7 +83,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
constexpr auto block_work_desc = constexpr auto block_work_desc =
make_ConstantTensorDescriptor(Sequence<KBlockWork, BBlockWork>{}); make_ConstantTensorDescriptor(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id()); const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock; const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock; const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
...@@ -219,10 +220,10 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw ...@@ -219,10 +220,10 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
{ {
// calculate origin of block input and weight tensor on global memory // calculate origin of block input and weight tensor on global memory
const Float* p_in_block_on_global = const Float* p_in_block_on_global =
p_in_global + in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, x); p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x);
const Float* p_wei_block_on_global = const Float* p_wei_block_on_global =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, 0); p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0);
for(index_t for(index_t
c_block_data_on_global = 0; c_block_data_on_global = 0;
...@@ -285,7 +286,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw ...@@ -285,7 +286,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
// origin of thread tensor in global memory // origin of thread tensor in global memory
const index_t p_out_thread_on_global = const index_t p_out_thread_on_global =
p_out_global + p_out_global +
out_k_n1_b_n2_global_merged_desc.Get1dIndex( out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, 0, 0, 0); // dst origin on merged global tensor k_thread_data_on_global, 0, 0, 0); // dst origin on merged global tensor
// copy // copy
......
...@@ -190,18 +190,19 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( ...@@ -190,18 +190,19 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
c_block_data_begin += CPerBlock, __syncthreads()) c_block_data_begin += CPerBlock, __syncthreads())
{ {
// copy input tensor to LDS // copy input tensor to LDS
blockwise_in_copy.Run(p_in_vec_global + blockwise_in_copy.Run(
in_nchw_vec_global_desc.Get1dIndex(n_block_data_begin, p_in_vec_global +
c_block_data_begin, in_nchw_vec_global_desc.GetOffsetFromMultiIndex(n_block_data_begin,
hi_block_data_begin, c_block_data_begin,
wi_block_data_begin), hi_block_data_begin,
p_in_vec_block); wi_block_data_begin),
p_in_vec_block);
// copy weight tensor to LDS // copy weight tensor to LDS
blockwise_wei_copy.Run( blockwise_wei_copy.Run(p_wei_vec_global +
p_wei_vec_global + wei_kcyx_vec_global_desc.GetOffsetFromMultiIndex(
wei_kcyx_vec_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0), k_block_data_begin, c_block_data_begin, 0, 0),
p_wei_vec_block); p_wei_vec_block);
__syncthreads(); __syncthreads();
...@@ -212,26 +213,28 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( ...@@ -212,26 +213,28 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
threadwise_direct_convolution_2( threadwise_direct_convolution_2(
in_nchw_vec_thread_block_desc, in_nchw_vec_thread_block_desc,
p_in_vec_block + p_in_vec_block +
in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin, in_nchw_vec_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin,
c_thread_data, c_thread_data,
hi_thread_data_begin, hi_thread_data_begin,
wi_thread_data_begin), wi_thread_data_begin),
wei_kcyx_vec_thread_block_desc, wei_kcyx_vec_thread_block_desc,
p_wei_vec_block + p_wei_vec_block +
wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), wei_kcyx_vec_block_desc.GetOffsetFromMultiIndex(
k_thread_data_begin, c_thread_data, 0, 0),
out_nkhw_thread_desc, out_nkhw_thread_desc,
p_out_thread); p_out_thread);
#elif 0 #elif 0
threadwise_direct_convolution_3( threadwise_direct_convolution_3(
in_nchw_vec_thread_block_desc, in_nchw_vec_thread_block_desc,
p_in_vec_block + p_in_vec_block +
in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin, in_nchw_vec_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin,
c_thread_data, c_thread_data,
hi_thread_data_begin, hi_thread_data_begin,
wi_thread_data_begin), wi_thread_data_begin),
wei_kcyx_vec_thread_block_desc, wei_kcyx_vec_thread_block_desc,
p_wei_vec_block + p_wei_vec_block +
wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), wei_kcyx_vec_block_desc.GetOffsetFromMultiIndex(
k_thread_data_begin, c_thread_data, 0, 0),
out_nkhw_thread_desc, out_nkhw_thread_desc,
p_out_thread); p_out_thread);
#endif #endif
...@@ -239,14 +242,14 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( ...@@ -239,14 +242,14 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
} }
// copy output tensor from register to global mem // copy output tensor from register to global mem
threadwise_4d_tensor_copy( threadwise_4d_tensor_copy(out_nkhw_thread_desc,
out_nkhw_thread_desc, p_out_thread,
p_out_thread, out_nkhw_global_desc,
out_nkhw_global_desc, p_out_global +
p_out_global + out_nkhw_global_desc.GetOffsetFromMultiIndex(
out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, n_block_data_begin + n_thread_data_begin,
k_block_data_begin + k_thread_data_begin, k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin), wo_block_data_begin + wo_thread_data_begin),
out_nkhw_thread_desc.GetLengths()); out_nkhw_thread_desc.GetLengths());
} }
...@@ -217,7 +217,7 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded( ...@@ -217,7 +217,7 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread); threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);
const Float* p_wei_global_block_begin = const Float* p_wei_global_block_begin =
p_wei_global + wei_ek_global_desc.Get1dIndex(0, k_block_data_begin); p_wei_global + wei_ek_global_desc.GetOffsetFromMultiIndex(0, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0), p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0),
...@@ -251,10 +251,11 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded( ...@@ -251,10 +251,11 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
{ {
auto f_accum = [](auto& acc, const auto&& v) { acc += v; }; auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
blockwise_batch_gemm.Run(p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0), blockwise_batch_gemm.Run(
p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0), p_wei_block + wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
p_out_thread, p_in_block + in_chwn_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
f_accum); p_out_thread,
f_accum);
} }
} }
} }
...@@ -284,10 +285,10 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded( ...@@ -284,10 +285,10 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
p_out_thread, p_out_thread,
out_khwn_global_desc, out_khwn_global_desc,
p_out_global + p_out_global +
out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, out_khwn_global_desc.GetOffsetFromMultiIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin), n_block_data_begin + n_thread_data_begin),
out_hkwn_thread_desc.GetLengths(), out_hkwn_thread_desc.GetLengths(),
reorder_khwn_from_hkwn); reorder_khwn_from_hkwn);
} }
...@@ -93,7 +93,7 @@ struct TensorDescriptor ...@@ -93,7 +93,7 @@ struct TensorDescriptor
const std::vector<std::size_t>& GetStrides() const; const std::vector<std::size_t>& GetStrides() const;
template <class... Is> template <class... Is>
std::size_t Get1dIndex(Is... is) const std::size_t GetOffsetFromMultiIndex(Is... is) const
{ {
assert(sizeof...(Is) == this->GetNumOfDimension()); assert(sizeof...(Is) == this->GetNumOfDimension());
std::initializer_list<std::size_t> iss{static_cast<std::size_t>(is)...}; std::initializer_list<std::size_t> iss{static_cast<std::size_t>(is)...};
...@@ -246,13 +246,13 @@ struct Tensor ...@@ -246,13 +246,13 @@ struct Tensor
template <class... Is> template <class... Is>
T& operator()(Is... is) T& operator()(Is... is)
{ {
return mData[mDesc.Get1dIndex(is...)]; return mData[mDesc.GetOffsetFromMultiIndex(is...)];
} }
template <class... Is> template <class... Is>
const T& operator()(Is... is) const const T& operator()(Is... is) const
{ {
return mData[mDesc.Get1dIndex(is...)]; return mData[mDesc.GetOffsetFromMultiIndex(is...)];
} }
typename std::vector<T>::iterator begin() { return mData.begin(); } typename std::vector<T>::iterator begin() { return mData.begin(); }
......
...@@ -20,7 +20,7 @@ __device__ void threadwise_2d_tensor_pointwise_operation_unary(Desc, Float* __re ...@@ -20,7 +20,7 @@ __device__ void threadwise_2d_tensor_pointwise_operation_unary(Desc, Float* __re
{ {
for(index_t did1 = 0; did1 < desc.GetLength(I1); ++did1) for(index_t did1 = 0; did1 < desc.GetLength(I1); ++did1)
{ {
const index_t dindex = desc.Get1dIndex(did0, did1); const index_t dindex = desc.GetOffsetFromMultiIndex(did0, did1);
f(p[dindex]); f(p[dindex]);
} }
...@@ -53,11 +53,11 @@ __device__ void threadwise_2d_tensor_pointwise_operation_binary_reorder_by_get_d ...@@ -53,11 +53,11 @@ __device__ void threadwise_2d_tensor_pointwise_operation_binary_reorder_by_get_d
{ {
for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1) for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
{ {
const index_t aindex = src_desc.Get1dIndex(did0, did1); const index_t aindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
const index_t did[2] = {did0, did1}; const index_t did[2] = {did0, did1};
const index_t bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]); const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1]);
f(p_src[aindex], p_dst[bindex]); f(p_src[aindex], p_dst[bindex]);
} }
...@@ -127,7 +127,7 @@ __device__ void threadwise_2d_tensor_shift_down(Desc, Float* __restrict__ p, IDi ...@@ -127,7 +127,7 @@ __device__ void threadwise_2d_tensor_shift_down(Desc, Float* __restrict__ p, IDi
{ {
for(index_t did1 = 0; did1 < did1_end; ++did1) for(index_t did1 = 0; did1 < did1_end; ++did1)
{ {
const index_t dindex = desc.Get1dIndex(did0, did1); const index_t dindex = desc.GetOffsetFromMultiIndex(did0, did1);
const index_t sindex = dindex + nshift * desc.GetStride(IDim{}); const index_t sindex = dindex + nshift * desc.GetStride(IDim{});
......
...@@ -26,7 +26,7 @@ __device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __re ...@@ -26,7 +26,7 @@ __device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __re
{ {
for(index_t did3 = 0; did3 < desc.GetLength(I3); ++did3) for(index_t did3 = 0; did3 < desc.GetLength(I3); ++did3)
{ {
const index_t dindex = desc.Get1dIndex(did0, did1, did2, did3); const index_t dindex = desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);
f(p[dindex]); f(p[dindex]);
} }
...@@ -75,12 +75,12 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_given_ds ...@@ -75,12 +75,12 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_given_ds
{ {
for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3) for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
{ {
const index_t aindex = src_desc.Get1dIndex(did0, did1, did2, did3); const index_t aindex = src_desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);
const index_t did[4] = {did0, did1, did2, did3}; const index_t did[4] = {did0, did1, did2, did3};
const index_t bindex = const index_t bindex =
dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]); dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
f(p_src[aindex], p_dst[bindex]); f(p_src[aindex], p_dst[bindex]);
...@@ -178,7 +178,7 @@ __device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDi ...@@ -178,7 +178,7 @@ __device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDi
{ {
for(index_t did3 = 0; did3 < did3_end; ++did3) for(index_t did3 = 0; did3 < did3_end; ++did3)
{ {
const index_t dindex = desc.Get1dIndex(did0, did1, did2, did3); const index_t dindex = desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);
const index_t sindex = dindex + nshift * desc.GetStride(IDim{}); const index_t sindex = dindex + nshift * desc.GetStride(IDim{});
......
...@@ -46,11 +46,14 @@ __device__ void threadwise_direct_convolution_1(InDesc, ...@@ -46,11 +46,14 @@ __device__ void threadwise_direct_convolution_1(InDesc,
const index_t hi = ho + y; const index_t hi = ho + y;
const index_t wi = wo + x; const index_t wi = wo + x;
const index_t in_index = in_desc.Get1dIndex(n, c, hi, wi); const index_t in_index =
in_desc.GetOffsetFromMultiIndex(n, c, hi, wi);
const index_t wei_index = wei_desc.Get1dIndex(k, c, y, x); const index_t wei_index =
wei_desc.GetOffsetFromMultiIndex(k, c, y, x);
const index_t out_index = out_desc.Get1dIndex(n, k, ho, wo); const index_t out_index =
out_desc.GetOffsetFromMultiIndex(n, k, ho, wo);
fused_multiply_accumulate( fused_multiply_accumulate(
p_out[out_index], p_wei[wei_index], p_in[in_index]); p_out[out_index], p_wei[wei_index], p_in[in_index]);
...@@ -143,14 +146,14 @@ __device__ void threadwise_direct_convolution_3(InDesc, ...@@ -143,14 +146,14 @@ __device__ void threadwise_direct_convolution_3(InDesc,
{ {
// read first input // read first input
threadwise_4d_tensor_copy(in_desc, threadwise_4d_tensor_copy(in_desc,
p_in + in_desc.Get1dIndex(0, 0, y, 0), p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, 0),
in_reg_desc, in_reg_desc,
p_in_reg, p_in_reg,
in_reg_desc.GetLengths()); in_reg_desc.GetLengths());
// read first 1x1 weight // read first 1x1 weight
threadwise_4d_tensor_copy(wei_desc, threadwise_4d_tensor_copy(wei_desc,
p_wei + wei_desc.Get1dIndex(0, 0, y, 0), p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, 0),
wei_reg_desc, wei_reg_desc,
p_wei_reg, p_wei_reg,
wei_reg_desc.GetLengths()); wei_reg_desc.GetLengths());
...@@ -164,7 +167,7 @@ __device__ void threadwise_direct_convolution_3(InDesc, ...@@ -164,7 +167,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
{ {
// read new weight // read new weight
threadwise_4d_tensor_copy(wei_desc, threadwise_4d_tensor_copy(wei_desc,
p_wei + wei_desc.Get1dIndex(0, 0, y, x), p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, x),
wei_reg_desc, wei_reg_desc,
p_wei_reg, p_wei_reg,
wei_reg_desc.GetLengths()); wei_reg_desc.GetLengths());
...@@ -175,10 +178,10 @@ __device__ void threadwise_direct_convolution_3(InDesc, ...@@ -175,10 +178,10 @@ __device__ void threadwise_direct_convolution_3(InDesc,
// read new input // read new input
threadwise_4d_tensor_copy( threadwise_4d_tensor_copy(
in_desc, in_desc,
p_in + in_desc.Get1dIndex(0, 0, y, x + in_reg_desc.GetLength(I3) - 1), p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, x + in_reg_desc.GetLength(I3) - 1),
in_reg_desc, in_reg_desc,
p_in_reg + p_in_reg +
in_reg_desc.Get1dIndex(0, 0, 0, in_reg_desc.GetLength(I3) - in_w_new_read), in_reg_desc.GetOffsetFromMultiIndex(0, 0, 0, in_reg_desc.GetLength(I3) - in_w_new_read),
in_desc_reg_new_read.GetLengths()); in_desc_reg_new_read.GetLengths());
// do 1x1 conv // do 1x1 conv
...@@ -196,14 +199,14 @@ __device__ void threadwise_direct_convolution_3(InDesc, ...@@ -196,14 +199,14 @@ __device__ void threadwise_direct_convolution_3(InDesc,
{ {
// read new weight // read new weight
threadwise_4d_tensor_copy(wei_desc, threadwise_4d_tensor_copy(wei_desc,
p_wei + wei_desc.Get1dIndex(0, 0, y, x), p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, x),
wei_reg_desc, wei_reg_desc,
p_wei_reg, p_wei_reg,
wei_reg_desc.GetLengths()); wei_reg_desc.GetLengths());
// read new input // read new input
threadwise_4d_tensor_copy(in_desc, threadwise_4d_tensor_copy(in_desc,
p_in + in_desc.Get1dIndex(0, 0, y, x), p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, x),
in_reg_desc, in_reg_desc,
p_in_reg, p_in_reg,
in_reg_desc.GetLengths()); in_reg_desc.GetLengths());
......
...@@ -9,7 +9,7 @@ __device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread) ...@@ -9,7 +9,7 @@ __device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
{ {
for(index_t j = 0; j < Matrix::NCol(); ++j) for(index_t j = 0; j < Matrix::NCol(); ++j)
{ {
const index_t id = Matrix::Get1dIndex(i, j); const index_t id = Matrix::GetOffsetFromMultiIndex(i, j);
p_thread[id] = 0; p_thread[id] = 0;
} }
} }
...@@ -39,8 +39,8 @@ __device__ void threadwise_matrix_copy(SrcMatrix, ...@@ -39,8 +39,8 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
{ {
for(index_t j = 0; j < NCol; j += DataPerRead) for(index_t j = 0; j < NCol; j += DataPerRead)
{ {
const index_t src_index = src_mtx.Get1dIndex(i, j); const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
const index_t dst_index = dst_mtx.Get1dIndex(i, j); const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
*reinterpret_cast<vector_t*>(&p_dst[dst_index]) = *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index]); *reinterpret_cast<const vector_t*>(&p_src[src_index]);
...@@ -83,9 +83,9 @@ __device__ void threadwise_gemm(MatrixA, ...@@ -83,9 +83,9 @@ __device__ void threadwise_gemm(MatrixA,
{ {
for(index_t j = 0; j < N; ++j) for(index_t j = 0; j < N; ++j)
{ {
const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed const index_t aindex = a_mtx.GetOffsetFromMultiIndex(k, i); // A is transposed
const index_t bindex = b_mtx.Get1dIndex(k, j); const index_t bindex = b_mtx.GetOffsetFromMultiIndex(k, j);
const index_t cindex = c_mtx.Get1dIndex(i, j); const index_t cindex = c_mtx.GetOffsetFromMultiIndex(i, j);
p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex]; p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex];
} }
......
...@@ -19,7 +19,7 @@ __device__ void threadwise_tensor_slice_copy(SrcDesc, ...@@ -19,7 +19,7 @@ __device__ void threadwise_tensor_slice_copy(SrcDesc,
constexpr auto src_desc = SrcDesc{}; constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{}; constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{}); constexpr auto ref_desc = make_packed_ConstantTensorDescriptor(SrcOpLengths{});
#if 0 #if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
...@@ -53,9 +53,9 @@ __device__ void threadwise_tensor_slice_copy(SrcDesc, ...@@ -53,9 +53,9 @@ __device__ void threadwise_tensor_slice_copy(SrcDesc,
static_for<0, nRead, 1>{}([&](auto IRead) { static_for<0, nRead, 1>{}([&](auto IRead) {
constexpr auto multi_id = decltype(Ids){}.PushBack(Number<IRead.Get() * DataPerRead>{}); constexpr auto multi_id = decltype(Ids){}.PushBack(Number<IRead.Get() * DataPerRead>{});
const index_t src_index = src_desc.Get1dIndex(multi_id); const index_t src_index = src_desc.GetOffsetFromMultiIndex(multi_id);
const index_t dst_index = dst_desc.Get1dIndex(multi_id); const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(multi_id);
*(reinterpret_cast<vector_t*>(&p_dst[dst_index])) = *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
*(reinterpret_cast<const vector_t*>(&p_src[src_index])); *(reinterpret_cast<const vector_t*>(&p_src[src_index]));
...@@ -84,9 +84,9 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v1(SrcDesc, ...@@ -84,9 +84,9 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v1(SrcDesc,
ford<SrcOpLengths>{}([&](auto src_multi_id) { ford<SrcOpLengths>{}([&](auto src_multi_id) {
const auto dst_multi_id = reorder_array_given_new2old(src_multi_id, MapDst2Src{}); const auto dst_multi_id = reorder_array_given_new2old(src_multi_id, MapDst2Src{});
const index_t dst_index = dst_desc.Get1dIndex(dst_multi_id); const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
const index_t src_index = src_desc.Get1dIndex(src_multi_id); const index_t src_index = src_desc.GetOffsetFromMultiIndex(src_multi_id);
p_dst[dst_index] = p_src[src_index]; p_dst[dst_index] = p_src[src_index];
}); });
...@@ -115,9 +115,9 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v2(SrcDesc, ...@@ -115,9 +115,9 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v2(SrcDesc,
ford<decltype(dst_op_lengths)>{}([&](auto dst_multi_id) { ford<decltype(dst_op_lengths)>{}([&](auto dst_multi_id) {
const auto src_multi_id = reorder_array_given_old2new(dst_multi_id, MapDst2Src{}); const auto src_multi_id = reorder_array_given_old2new(dst_multi_id, MapDst2Src{});
const index_t dst_index = dst_desc.Get1dIndex(dst_multi_id); const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
const index_t src_index = src_desc.Get1dIndex(src_multi_id); const index_t src_index = src_desc.GetOffsetFromMultiIndex(src_multi_id);
p_dst[dst_index] = p_src[src_index]; p_dst[dst_index] = p_src[src_index];
}); });
...@@ -177,7 +177,7 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc, ...@@ -177,7 +177,7 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
const auto src_multi_id = reorder_array_given_old2new(dst_multi_id, MapDst2Src{}); const auto src_multi_id = reorder_array_given_old2new(dst_multi_id, MapDst2Src{});
const index_t src_index = src_desc.Get1dIndex(src_multi_id); const index_t src_index = src_desc.GetOffsetFromMultiIndex(src_multi_id);
vector_type<Float, DstDataPerWrite>::SetScalar( vector_type<Float, DstDataPerWrite>::SetScalar(
dst_vec_data, p_src[src_index], IDstData); dst_vec_data, p_src[src_index], IDstData);
...@@ -186,7 +186,7 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc, ...@@ -186,7 +186,7 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
// write data // write data
const auto dst_multi_id = ids.PushBack(IWrite.Get() * DstDataPerWrite); const auto dst_multi_id = ids.PushBack(IWrite.Get() * DstDataPerWrite);
const index_t dst_index = dst_desc.Get1dIndex(dst_multi_id); const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
*(reinterpret_cast<vector_t*>(&p_dst[dst_index])) = dst_vec_data; *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) = dst_vec_data;
}); });
...@@ -204,5 +204,21 @@ threadwise_tensor_slice_copy_generic(SrcDesc, ...@@ -204,5 +204,21 @@ threadwise_tensor_slice_copy_generic(SrcDesc,
SliceLengths, SliceLengths,
DimAccessOrder) DimAccessOrder)
{ {
// not implemented constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto slice_lengths_in_access_order =
SliceLengths{}.ReorderGivenNew2Old(DimAccessOrder{});
ford<decltype(slice_lengths_in_access_order)>{}([&](auto data_multi_id_in_access_order) {
const auto data_multi_id =
reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{});
const index_t dst_index =
dst_desc.GetOffsetFromMultiIndex(src_multi_offset + data_multi_id);
const index_t src_index =
src_desc.GetOffsetFromMultiIndex(dst_multi_offset + data_multi_id);
p_dst[dst_index] = p_src[src_index];
});
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment