Commit 545d9305 authored by Chao Liu

refactor

parent 37f4e2b6
@@ -100,7 +100,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
     constexpr index_t E = C * Y * X;

     // sanity-check for vectorized memory load
-    static_assert((Ho == 1 || ConvStrideW % InBlockCopySrcDataPerRead_B == 0) &&
+    static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
                       (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
                   "wrong! aligment requirement for vectorized global load of input tensor will "
                   "be violated");
...
@@ -100,7 +100,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
     constexpr index_t E = C * Y * X;

     // sanity-check for vectorized memory load
-    static_assert((Ho == 1 || ConvStrideW % InBlockCopySrcDataPerRead_B == 0) &&
+    static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
                       (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
                   "wrong! aligment requirement for vectorized global load of input tensor will "
                   "be violated");
...
@@ -107,7 +107,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded
     constexpr index_t E = C * Y * X;

     // sanity-check for vectorized memory load
-    static_assert((Ho == 1 || ConvStrideW % InBlockCopySrcDataPerRead_B == 0) &&
+    static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
                       (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
                   "wrong! aligment requirement for vectorized global load of input tensor will "
                   "be violated");
...
@@ -107,7 +107,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
     constexpr index_t E = C * Y * X;

     // sanity-check for vectorized memory load
-    static_assert((Ho == 1 || ConvStrideW % InBlockCopySrcDataPerRead_B == 0) &&
+    static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) &&
                       (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
                   "wrong! aligment requirement for vectorized global load of input tensor will "
                   "be violated");
@@ -174,9 +174,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
                                            decltype(in_e_n1_b_n2_global_desc),
                                            decltype(in_e_n1_b_n2_block_desc),
                                            Sequence<0, 1, 0, 1>,
-                                           Sequence<1, 0, 1, 0>,
                                            Sequence<1, 1, 1, 1>,
-                                           Sequence<0, 0, 0, 0>,
                                            decltype(in_e_n1_b_n2_block_desc.GetLengths()),
                                            InBlockCopySubLengths_E_N1_B_N2,
                                            InBlockCopyClusterLengths_E_N1_B_N2,
@@ -219,9 +217,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
                                            decltype(wei_e_k_global_desc),
                                            decltype(wei_e_k_block_desc),
                                            Sequence<1, 1>,
-                                           Sequence<0, 0>,
                                            Sequence<1, 1>,
-                                           Sequence<0, 0>,
                                            decltype(wei_e_k_block_desc.GetLengths()),
                                            WeiBlockCopySubLengths_E_K,
                                            WeiBlockCopyClusterLengths_E_K,
@@ -299,8 +295,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
         // LDS double buffer: preload data into LDS
         {
-            blockwise_in_copy.Run(p_in_global, p_in_block_double);
-            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
+            blockwise_in_copy.template Run<Float, address_space_t::global, address_space_t::lds>(
+                p_in_global, p_in_block_double);
+            blockwise_wei_copy.template Run<Float, address_space_t::global, address_space_t::lds>(
+                p_wei_global, p_wei_block_double);
         }

         // LDS double buffer: main body
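With this change the copy entry points take source and destination address spaces as template arguments instead of hard-coding them inside the copy class. A simplified, hypothetical sketch of that pattern (the enum and the dispatch below are stand-ins, not the library's actual definitions):

enum class address_space_t
{
    generic,
    global,
    lds,
    vgpr
};

// hypothetical stand-in for the blockwise copy: the caller states where source and
// destination live, so the copy no longer assumes "src is global, dst is LDS"
template <address_space_t SrcAddressSpace, address_space_t DstAddressSpace>
void run_copy(const float* p_src, float* p_dst, int n)
{
    // a real implementation would pick buffer_load / ds_read / ds_write paths from the
    // address spaces; this sketch only models the generic fallback
    for(int i = 0; i < n; ++i)
    {
        p_dst[i] = p_src[i];
    }
}

// usage mirroring the call sites in this diff:
//   run_copy<address_space_t::global, address_space_t::lds>(p_in_global, p_in_block, n);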
@@ -331,15 +329,19 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
             __syncthreads();

             // LDS doubel buffer: load next data from device mem
-            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global, p_wei_register_buffer);
+            blockwise_in_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
+                p_in_global, p_in_register_buffer);
+            blockwise_wei_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
+                p_wei_global, p_wei_register_buffer);

             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

             // LDS double buffer: store next data to LDS
-            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
-            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
+            blockwise_in_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
+                p_in_register_buffer, p_in_block_next);
+            blockwise_wei_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
+                p_wei_register_buffer, p_wei_block_next);
         }
     }
@@ -355,17 +357,19 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
             __syncthreads();

             // LDS doubel buffer: load next data from device mem
-            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global, p_wei_register_buffer);
+            blockwise_in_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
+                p_in_global, p_in_register_buffer);
+            blockwise_wei_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
+                p_wei_global, p_wei_register_buffer);

             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

             // LDS double buffer: store next data to LDS
-            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
-                                                     p_in_block_double + in_block_space);
-            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
-                                                      p_wei_block_double + wei_block_space);
+            blockwise_in_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
+                p_in_register_buffer, p_in_block_double + in_block_space);
+            blockwise_wei_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
+                p_wei_register_buffer, p_wei_block_double + wei_block_space);

             // odd iteration
             __syncthreads();
@@ -424,9 +428,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
             ThreadwiseGenericTensorSliceCopy_v4r2<decltype(out_k0_k1_n1_b_n2_thread_desc),
                                                   decltype(out_k0_k1_n1_b_n2_global_desc),
                                                   Sequence<1, 1, 1, 1, 1>,
-                                                  Sequence<0, 0, 0, 0, 0>,
                                                   Sequence<1, 1, 1, 0, 1>,
-                                                  Sequence<0, 0, 0, 1, 0>,
                                                   decltype(
                                                       out_k0_k1_n1_b_n2_thread_desc.GetLengths()),
                                                   arithmetic_sequence_gen<0, 5, 1>::type,
...
@@ -84,7 +84,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
     constexpr index_t B = N * Ho * Wo;

     // sanity-check for vectorized memory load
-    static_assert((Ho == 1 || ConvStrideW % InBlockCopyDataPerAccess_B == 0) &&
+    static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
                       (X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
                   "wrong! aligment requirement for vectorized global load of input tensor will "
                   "be violated");
...
@@ -84,7 +84,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
     constexpr index_t B = N * Ho * Wo;

     // sanity-check for vectorized memory load
-    static_assert((Ho == 1 || ConvStrideW % InBlockCopyDataPerAccess_B == 0) &&
+    static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
                       (X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
                   "wrong! aligment requirement for vectorized global load of input tensor will "
                   "be violated");
...
@@ -91,7 +91,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded
     constexpr index_t B = N * Ho * Wo;

     // sanity-check for vectorized memory load
-    static_assert((Ho == 1 || ConvStrideW % InBlockCopyDataPerAccess_B == 0) &&
+    static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
                       (X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
                   "wrong! aligment requirement for vectorized global load of input tensor will "
                   "be violated");
...
@@ -90,7 +90,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
     constexpr index_t B = N * Ho * Wo;

     // sanity-check for vectorized memory load
-    static_assert((Ho == 1 || ConvStrideW % InBlockCopyDataPerAccess_B == 0) &&
+    static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
                       (X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
                   "wrong! aligment requirement for vectorized global load of input tensor will "
                   "be violated");
@@ -145,6 +145,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
             BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                                decltype(in_e_b_global_desc),
                                                decltype(in_e_b_block_desc),
+                                               Sequence<0, 0>,
+                                               Sequence<1, 1>,
                                                decltype(in_e_b_block_desc.GetLengths()),
                                                InBlockCopySubLengths_E_B,
                                                InBlockCopyClusterLengths_E_B,
@@ -157,13 +159,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
                                                InBlockCopyDataPerAccess_B>(
                 {0, b_block_data_on_global}, {0, 0});

         // weight tensor
         //   global mem
+#if 0
         constexpr auto wei_e_k_global_desc =
             transform_tensor_descriptor(wei_k_c_y_x_global_desc,
                                         make_tuple(Merge<Sequence<C, Y, X>>{}, PassThrough<K>{}),
                                         make_tuple(Sequence<1, 2, 3>{}, Sequence<0>{}),
                                         make_tuple(Sequence<0>{}, Sequence<1>{}));
+#else // hack
+        constexpr auto wei_e_k_global_desc_old =
+            WeiGlobalDesc::Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
+
+        constexpr auto wei_e_k_global_desc = make_native_tensor_descriptor(
+            wei_e_k_global_desc_old.GetLengths(), wei_e_k_global_desc_old.GetStrides());
+#endif

         //   LDS
         //     be careful of LDS alignment
@@ -176,6 +186,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
             BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                                decltype(wei_e_k_global_desc),
                                                decltype(wei_e_k_block_desc),
+                                               Sequence<1, 1>,
+                                               Sequence<1, 1>,
                                                decltype(wei_e_k_block_desc.GetLengths()),
                                                WeiBlockCopySubLengths_E_K,
                                                WeiBlockCopyClusterLengths_E_K,
@@ -253,8 +265,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
         // LDS double buffer: preload data into LDS
         {
-            blockwise_in_copy.Run(p_in_global, p_in_block_double);
-            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
+            blockwise_in_copy.template Run<Float, address_space_t::global, address_space_t::lds>(
+                p_in_global, p_in_block_double);
+            blockwise_wei_copy.template Run<Float, address_space_t::global, address_space_t::lds>(
+                p_wei_global, p_wei_block_double);
         }

         // LDS double buffer: main body
@@ -285,15 +299,19 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
             __syncthreads();

             // LDS doubel buffer: load next data from device mem
-            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global, p_wei_register_buffer);
+            blockwise_in_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
+                p_in_global, p_in_register_buffer);
+            blockwise_wei_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
+                p_wei_global, p_wei_register_buffer);

             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

             // LDS double buffer: store next data to LDS
-            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
-            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
+            blockwise_in_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
+                p_in_register_buffer, p_in_block_next);
+            blockwise_wei_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
+                p_wei_register_buffer, p_wei_block_next);
         }
     }
@@ -309,17 +327,19 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
             __syncthreads();

             // LDS doubel buffer: load next data from device mem
-            blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-            blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global, p_wei_register_buffer);
+            blockwise_in_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
+                p_in_global, p_in_register_buffer);
+            blockwise_wei_copy.template RunLoadRegisterBuffer<Float, address_space_t::global>(
+                p_wei_global, p_wei_register_buffer);

             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

             // LDS double buffer: store next data to LDS
-            blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
-                                                     p_in_block_double + in_block_space);
-            blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
-                                                      p_wei_block_double + wei_block_space);
+            blockwise_in_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
+                p_in_register_buffer, p_in_block_double + in_block_space);
+            blockwise_wei_copy.template RunStoreRegisterBuffer<Float, address_space_t::lds>(
+                p_wei_register_buffer, p_wei_block_double + wei_block_space);

             // odd iteration
             __syncthreads();
@@ -367,9 +387,11 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
                                         make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

             // output threadwise copy
-            auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v4r2<
+            ThreadwiseGenericTensorSliceCopy_v4r2<
                 decltype(out_k0_k1_b0_b1_thread_desc),
                 decltype(out_k0_k1_b0_b1_global_desc),
+                Sequence<1, 1, 1, 1>,
+                Sequence<1, 1, 0, 0>,
                 decltype(out_k0_k1_b0_b1_thread_desc.GetLengths()),
                 arithmetic_sequence_gen<0, 4, 1>::type,
                 3,
@@ -378,9 +400,15 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
                 {k_thread_data_on_global / K1,
                  k_thread_data_on_global % K1,
                  b_thread_data_on_global / B1,
-                 b_thread_data_on_global % B1});
-
-            threadwise_out_copy.Run(p_out_thread, p_out_global);
+                 b_thread_data_on_global % B1})
+#if 1
+                .template Run_generic<Float, address_space_t::generic, address_space_t::global>
+#elif 1
+                .template Run_optimized_dst_address_calculation<Float,
+                                                                address_space_t::vgpr,
+                                                                address_space_t::global>
+#endif
+                (p_out_thread, p_out_global);
         }
     }
 };
...
@@ -96,13 +96,12 @@ struct ConstantTensorDescriptor
     __host__ __device__ static constexpr auto GetElementSize()
     {
-        return Number<accumulate_on_sequence(
-            Lengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
+        return Number<reduce_on_sequence(Lengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
     }

     __host__ __device__ static constexpr auto GetElementSpace()
     {
-        constexpr index_t element_space_unaligned = accumulate_on_sequence(
+        constexpr index_t element_space_unaligned = reduce_on_sequence(
             (GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});

         return Number<element_space_unaligned>{};
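accumulate_on_sequence is renamed to reduce_on_sequence throughout the commit; the behavior is a left fold of a compile-time sequence with a binary functor and an initial value. A standalone analogue (illustrative names, not the library code):

template <typename Reduce, typename T>
constexpr T fold(Reduce, T init)
{
    return init;
}

template <typename Reduce, typename T, typename... Ts>
constexpr T fold(Reduce f, T init, T head, Ts... tail)
{
    // left fold: combine the accumulator with the head, then recurse on the tail
    return fold(f, f(init, head), tail...);
}

struct multiplies
{
    constexpr int operator()(int a, int b) const { return a * b; }
};

// product of tensor lengths {8, 3, 3, 32} with initial value 1 is 2304 -- the same kind
// of reduction GetElementSize() performs with math::multiplies<index_t>{} and Number<1>{}
static_assert(fold(multiplies{}, 1, 8, 3, 3, 32) == 2304, "reduce over a length sequence");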
@@ -155,7 +154,7 @@ struct ConstantTensorDescriptor
         constexpr auto multi_id = Sequence<Is...>{};

-        return Number<accumulate_on_sequence(
+        return Number<reduce_on_sequence(
             multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{})>{};
     }
@@ -389,7 +388,7 @@ struct ConstantTensorDescriptor
         constexpr auto fold_intervals = Sequence<FoldIntervals...>{};

         constexpr index_t fold_intervals_product =
-            accumulate_on_sequence(fold_intervals, math::multiplies<index_t>{}, Number<1>{});
+            reduce_on_sequence(fold_intervals, math::multiplies<index_t>{}, Number<1>{});

         constexpr auto unfold_length = GetLength(Number<IDim>{});
         constexpr auto unfold_stride = GetStride(Number<IDim>{});
@@ -447,7 +446,7 @@ struct ConstantTensorDescriptor
         static_assert(Type::Extract(middle).AreDimensionsContinuous(), "wrong! not unfoldable");

         // unfolded length, stride
-        constexpr index_t unfold_length = accumulate_on_sequence(
+        constexpr index_t unfold_length = reduce_on_sequence(
             GetLengths().Extract(middle), math::multiplies<index_t>{}, Number<1>{});
         constexpr index_t unfold_stride = GetStride(Number<LastUnfoldDim>{});
...
@@ -41,11 +41,10 @@ struct PassThrough
     __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

-    // TODO: should this function be here? should it be specific for padding check?
     __host__ __device__ static constexpr bool
-    IsUpperIndexInPaddingArea(const UpperIndex& /* idx_up */)
+    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
     {
-        return false;
+        return true;
     }
 };
@@ -82,24 +81,39 @@ struct Pad
     __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

-    // TODO: should this function be here? should it be specific for padding check?
-    __host__ __device__ constexpr bool IsUpperIndexInPaddingArea(const UpperIndex& idx_up) const
+    __host__ __device__ constexpr bool
+    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up) const
     {
-        bool flag = false;
+#if 0
+        struct lambda_no_pad
+        {
+            __host__ __device__ constexpr bool operator()(index_t x) const { return x == 0; }
+        };
+
+        if(sequence_all_of(LeftPads{}, lambda_no_pad{}) &&
+           sequence_all_of(RightPads{}, lambda_no_pad{}))
+        {
+            return true;
+        }
+        else
+#endif
+        {
+            bool flag = true;

             static_for<0, nDim, 1>{}([&](auto idim) {
                 // only check if there is left-padding
                 static_if<(LeftPads::At(idim) != 0)>{}(
-                    [&](auto) { flag = flag || idx_up[idim] < LeftPads::At(idim); });
+                    [&](auto) { flag = flag && idx_up[idim] >= LeftPads::At(idim); });

                 // only check if there is right-padding
                 static_if<(RightPads::At(idim) != 0)>{}([&](auto) {
-                    flag = flag || idx_up[idim] >= LeftPads::At(idim) + LowerLengths::At(idim);
+                    flag = flag && (idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim));
                 });
             });

             return flag;
+        }
     }
 };

 // LowerLengths: Sequence<...>
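The padding predicate is inverted by this commit: IsUpperIndexInPaddingArea returned true when any checked dimension fell inside a pad region, while IsUpperIndexMappedToValidLowerIndex returns true only when every padded dimension stays inside the real data. A hypothetical one-dimensional illustration (simplified, not the commit's code):

// with a left pad of 1 and a lower (unpadded) length of 4, upper indices 0..5 exist,
// but only 1..4 map to a valid lower index
constexpr bool is_upper_index_valid(int idx_up, int left_pad, int lower_length)
{
    return idx_up >= left_pad && idx_up < left_pad + lower_length;
}

static_assert(!is_upper_index_valid(0, 1, 4), "index 0 lies in the left padding");
static_assert(is_upper_index_valid(1, 1, 4), "first real element");
static_assert(is_upper_index_valid(4, 1, 4), "last real element");
static_assert(!is_upper_index_valid(5, 1, 4), "index 5 lies in the right padding");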
@@ -155,16 +169,10 @@ struct Merge
                 LowerLengths::PopFront(), math::multiplies<index_t>{}, Number<1>{})
                 .PushBack(Number<1>{});

-#if 1 // would these 2 versions be compiled to same ISA?
+        // calculate index in each of the dimensions in the order of their dimension
         static_for<0, nDimLow - 1, 1>{}(
             lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));

         idx_low(nDimLow - 1) = itmp / pseudo_low_strides[nDimLow - 1];
-#else
-        static_for<0, nDimLow, 1>{}(
-            lambda_CalculateLowerIndex<decltype(pseudo_low_strides)>(itmp, idx_low));
-#endif

         return idx_low;
     }
@@ -244,6 +252,7 @@ struct Merge
         });

         // highest dimension, no out-of-bound check
+
         if(borrow)
         {
             --idx_low_new(0);
@@ -255,11 +264,10 @@ struct Merge
     __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

-    // TODO: should this function be here? should it be specific for padding check?
     __host__ __device__ static constexpr bool
-    IsUpperIndexInPaddingArea(const UpperIndex& /* idx_up */)
+    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
     {
-        return false;
+        return true;
     }
 };
@@ -304,11 +312,10 @@ struct Unmerge
     __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

-    // TODO: should this function be here? should it be specific for padding check?
     __host__ __device__ static constexpr bool
-    IsUpperIndexInPaddingArea(const UpperIndex& /* idx_up */)
+    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
     {
-        return false;
+        return true;
     }
 };
@@ -362,9 +369,9 @@ struct Embed
     __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

     __host__ __device__ static constexpr bool
-    IsUpperIndexInPaddingArea(const UpperIndex& /* idx_up */)
+    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
     {
-        return false;
+        return true;
     }
 };
@@ -404,11 +411,10 @@ struct Vectorize
     __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

-    // TODO: should this function be here? should it be specific for padding check?
     __host__ __device__ static constexpr bool
-    IsUpperIndexInPaddingArea(const UpperIndex& /* idx_up */)
+    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& /* idx_up */)
     {
-        return false;
+        return true;
     }
 };
...
#ifndef CK_TENSOR_COORDINATE_HELPER_HPP
#define CK_TENSOR_COORDINATE_HELPER_HPP

#include "tensor_coordiante_v2.hpp"

namespace ck {

template <typename TensorDesc>
__host__ __device__ constexpr auto
make_tensor_coordinate_v2(TensorDesc, MultiIndex<TensorDesc::GetNumOfDimension()> idx)
{
    return typename TensorCoordinate_v2<TensorDesc>::type(idx);
}

} // namespace ck
#endif
@@ -76,8 +76,7 @@ struct NativeTensorCoordinate
         return coord;
     }

-    // TODO: should this function be here? should it be specific for padding check?
-    __host__ __device__ static constexpr bool IsAnyLevelIndexInPaddingArea() { return false; }
+    __host__ __device__ static constexpr bool IsUpperIndexMappedToValidOffset() { return true; }

     private:
     // mIndex may be saved and update, however, the value of some (or all) of its entries may
@@ -166,11 +165,11 @@ struct TransformedTensorCoordinate
         return coord_up;
     }

-    // TODO: should this function be here? should it be specific for padding check?
-    __host__ __device__ constexpr bool IsAnyLevelIndexInPaddingArea() const
+    // this function should be inexpensive, because there is no upper-to-lower index transformation
+    __host__ __device__ constexpr bool IsUpperIndexMappedToValidOffset() const
     {
-        return tensor_desc_type::IsUpperIndexInPaddingArea(GetIndex()) ||
-               mCoordLow.IsAnyLevelIndexInPaddingArea();
+        return tensor_desc_type::IsUpperIndexMappedToValidLowerIndex(GetIndex()) &&
+               mCoordLow.IsUpperIndexMappedToValidOffset();
     }

     private:
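The renamed coordinate check composes with &&: an upper index maps to a valid offset only if it is valid for this level of transforms and the cached lower coordinate is itself valid. A toy mock of that chaining (assumed types, not the library's classes):

struct LowerCoord
{
    bool valid;
    constexpr bool IsUpperIndexMappedToValidOffset() const { return valid; }
};

struct UpperCoord
{
    bool upper_valid;
    LowerCoord lower; // cached when the coordinate was constructed, which keeps the check cheap

    constexpr bool IsUpperIndexMappedToValidOffset() const
    {
        return upper_valid && lower.IsUpperIndexMappedToValidOffset();
    }
};

static_assert(UpperCoord{true, {true}}.IsUpperIndexMappedToValidOffset(), "valid at every level");
static_assert(!UpperCoord{true, {false}}.IsUpperIndexMappedToValidOffset(), "an invalid lower level propagates up");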
@@ -206,11 +205,5 @@ struct TensorCoordinate_v2
     using type = decltype(MakeDummyTensorCoordinate(TensorDesc{}));
 };

-template <typename TensorDesc>
-__host__ __device__ constexpr auto
-make_tensor_coordinate_v2(TensorDesc, MultiIndex<TensorDesc::GetNumOfDimension()> idx)
-{
-    return typename TensorCoordinate_v2<TensorDesc>::type(idx);
-}
-
-}
+} // namespace ck

 #endif
@@ -66,12 +66,12 @@ struct NativeTensorDescriptor
     __host__ __device__ static constexpr index_t GetElementSize()
     {
-        return accumulate_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
+        return reduce_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
     }

     __host__ __device__ static constexpr index_t GetElementSpace()
     {
-        return accumulate_on_sequence(
+        return reduce_on_sequence(
             (GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});
     }
@@ -120,10 +120,10 @@ struct NativeTensorDescriptor
     }
 #endif

-    // TODO: should this function be here? should it be specific for padding check?
-    __host__ __device__ static constexpr bool IsUpperIndexInPaddingArea(const Index& /* idx */)
+    __host__ __device__ static constexpr bool
+    IsUpperIndexMappedToValidOffset(const Index& /* idx */)
     {
-        return false;
+        return true;
     }
 };
@@ -290,7 +290,7 @@ struct TransformedTensorDescriptor
     __host__ __device__ static constexpr index_t GetElementSize()
     {
-        return accumulate_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
+        return reduce_on_sequence(GetLengths(), math::multiplies<index_t>{}, Number<1>{});
     }

     __host__ __device__ static constexpr index_t GetElementSpace()
@@ -375,7 +375,7 @@ struct TransformedTensorDescriptor
             constexpr bool is_linear_transform = tran.IsLinearTransform();

             // judge if all lower dimension are linear
-            constexpr bool is_all_low_dim_linear = math::accumulate_on_sequence(
+            constexpr bool is_all_low_dim_linear = math::reduce_on_sequence(
                 pick_sequence_elements_by_mask(
                     GetLowerTensorDescriptor().GetMaskOfLinearDimensions(), LowDimensionId{}),
                 math::logic_and<bool>{},
@@ -441,21 +441,32 @@ struct TransformedTensorDescriptor
     }
 #endif

-    // TODO: should this function be here? should it be specific for padding check?
-    __host__ __device__ static constexpr bool IsUpperIndexInPaddingArea(const UpperIndex& idx_up)
+    __host__ __device__ static constexpr bool
+    IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up)
     {
-        bool flag = false;
+        bool flag = true;

         static_for<0, nTransform, 1>{}([&](auto itran) {
             constexpr auto tran = Transforms{}.At(itran);

             const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran));

-            flag = flag || tran.IsUpperIndexInPaddingArea(to_array(idx_up_part));
+            flag = flag && tran.IsUpperIndexMappedToValidLowerIndex(to_array(idx_up_part));
         });

         return flag;
     }
+
+    // Whenever this function is called, it will call CalculateLowerIndex() recursively
+    // If you have created a tensor coordinate already, instead of calling this function,
+    // you should call TransformedTensorCoordinate::IsUpperIndexMappedToValidOffset()
+    __host__ __device__ static constexpr bool
+    IsUpperIndexMappedToValidOffset(const UpperIndex& idx_up)
+    {
+        return IsUpperIndexMappedToValidLowerIndex(idx_up) &&
+               GetLowerTensorDescriptor().IsUpperIndexMappedToValidOffset(
+                   CalculateLowerIndex(idx_up));
+    }
 };

 } // namespace ck
...
@@ -162,7 +162,7 @@ struct Blockwise3dTensorCopy3
                       "wrrong! BlockSize is not big enough for ThreadPerDims!");

         constexpr index_t num_active_thread =
-            accumulate_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});
+            reduce_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});

         if(BlockSize > num_active_thread)
         {
...
@@ -505,7 +505,7 @@ struct Blockwise4dTensorCopy3
                       "wrrong! BlockSize is not big enough for ThreadPerDims!");

         constexpr index_t num_active_thread =
-            accumulate_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});
+            reduce_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});

         if(BlockSize > num_active_thread)
         {
...
@@ -681,9 +681,7 @@ template <index_t BlockSize,
           typename SrcDesc,
           typename DstDesc,
           typename SrcLinearDimensionMask,
-          typename SrcNonLinearDimensionMask,
           typename DstLinearDimensionMask,
-          typename DstNonLinearDimensionMask,
           typename SliceLengths,
           typename SubLengths,
           typename ThreadClusterLengths,
@@ -738,45 +736,43 @@ struct BlockwiseGenericTensorSliceCopy_v4
         return RegisterBufferDesc::GetElementSpace();
     }

-    template <typename TData>
+    template <typename TData, address_space_t SrcAddressSpace = address_space_t::generic>
     __device__ void RunLoadRegisterBuffer(const TData* p_src, TData* p_buffer) const
     {
-#if 0
-        mThreadwiseLoad.Run_generic(p_src, p_buffer);
-#elif 1
-        // hardcoded: src is global memory
-        mThreadwiseLoad.template Run_generic<TData, address_space_t::global>(p_src, p_buffer);
-#elif 1
-        // hardcoded: src is global memory
-        mThreadwiseLoad
-            .template Run_optimized_src_address_calculation<TData, address_space_t::global>(
-                p_src, p_buffer);
+#if 1
+        mThreadwiseLoad.template Run_generic<TData, SrcAddressSpace, address_space_t::vgpr>(
+            p_src, p_buffer);
+#else
+        mThreadwiseLoad.template Run_optimized_src_address_calculation<TData,
+                                                                       SrcAddressSpace,
+                                                                       address_space_t::vgpr>(
+            p_src, p_buffer);
 #endif
     }
-    template <typename TData>
+    template <typename TData, address_space_t DstAddressSpace = address_space_t::generic>
     __device__ void RunStoreRegisterBuffer(const TData* p_buffer, TData* p_dst) const
     {
-#if 0
-        mThreadwiseStore.Run_generic(p_buffer, p_dst);
-#elif 1
-        // hardcoded: dst is lds
-        mThreadwiseStore.template Run_generic<TData, address_space_t::lds>(p_buffer, p_dst);
-#elif 1
-        // hardcoded: dst is lds
-        mThreadwiseStore
-            .template Run_optimized_dst_address_calculation<TData, address_space_t::lds>(p_buffer,
-                                                                                          p_dst);
+#if 1
+        mThreadwiseStore.template Run_generic<TData, address_space_t::vgpr, DstAddressSpace>(
+            p_buffer, p_dst);
+#else
+        mThreadwiseStore.template Run_optimized_dst_address_calculation<TData,
+                                                                        address_space_t::vgpr,
+                                                                        DstAddressSpace>(p_buffer,
                                                                                          p_dst);
 #endif
     }

-    template <typename TData>
+    template <typename TData,
+              address_space_t SrcAddressSpace = address_space_t::generic,
+              address_space_t DstAddressSpace = address_space_t::generic>
     __device__ void Run(const TData* p_src, TData* p_dst) const
     {
         TData p_buffer[GetRegisterBufferSize()];

-        RunLoadRegisterBuffer(p_src, p_buffer);
-        RunStoreRegisterBuffer(p_buffer, p_dst);
+        RunLoadRegisterBuffer<TData, SrcAddressSpace>(p_src, p_buffer);
+        RunStoreRegisterBuffer<TData, DstAddressSpace>(p_buffer, p_dst);
     }

     template <typename T, bool PositiveDirection>
@@ -802,9 +798,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
     ThreadwiseGenericTensorSliceCopy_v4r2<SrcDesc,
                                           RegisterBufferDesc,
                                           SrcLinearDimensionMask,
-                                          SrcNonLinearDimensionMask,
                                           typename uniform_sequence_gen<nDim, 1>::type,
-                                          typename uniform_sequence_gen<nDim, 0>::type,
                                           SubLengths,
                                           SrcDimAccessOrder,
                                           SrcVectorAccessDim,
@@ -815,9 +809,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
     ThreadwiseGenericTensorSliceCopy_v4r2<RegisterBufferDesc,
                                           DstDesc,
                                           typename uniform_sequence_gen<nDim, 1>::type,
-                                          typename uniform_sequence_gen<nDim, 0>::type,
                                           DstLinearDimensionMask,
-                                          DstNonLinearDimensionMask,
                                           SubLengths,
                                           DstDimAccessOrder,
                                           DstVectorAccessDim,
...
@@ -1131,9 +1131,7 @@ struct ThreadwiseGenericTensorSliceCopy_v3r1
 template <typename SrcDesc,
           typename DstDesc,
           typename SrcLinearDimensionMask,
-          typename SrcNonLinearDimensionMask,
           typename DstLinearDimensionMask,
-          typename DstNonLinearDimensionMask,
           typename SliceLengths,
           typename DimAccessOrder,
           index_t VectorAccessDim,
@@ -1231,8 +1229,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 // Check src vector's padding situation, only check the first data in this src
                 // vector. It's user's responsiblity to make sure all data in the src vector has
                 // the same padding situation
-                // TODO: not sure a dedicated IsAnyLevelIndexInPaddingArea() function is neccessary
-                if(!src_coord.IsAnyLevelIndexInPaddingArea())
+                if(src_coord.IsUpperIndexMappedToValidOffset())
                 {
                     static_if<SrcAddressSpace == address_space_t::global>{}([&](auto) {
 #if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
@@ -1260,13 +1257,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 const auto dst_coord = mDstSliceOrigin + (long_vector_data_begin_id + scalar_id);

                 // Check dst vector's padding situation, only check the first data in this dst
                 // vector. It's user's responsiblity to make sure all data in the dst vector has
                 // the same padding situation
-                // TODO: not sure a dedicated IsAnyLevelIndexInPaddingArea() function is neccessary
-#if 0 // tuning
-                if(!dst_coord.IsAnyLevelIndexInPaddingArea())
-#endif
+                if(dst_coord.IsUpperIndexMappedToValidOffset())
                 {
                     static_if<DstAddressSpace == address_space_t::global>{}([&](auto) {
 #if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
@@ -1303,7 +1297,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
     // Will do padding check on src data: Read 0 if src data is in padding area.
     // Will do padding check on dst data: No write if dst data is in paddin area.
     // This version is optimized for address calculation of src tensor
-    template <typename TData, address_space_t SrcAddressSpace = address_space_t::generic>
+    template <typename TData,
+              address_space_t SrcAddressSpace = address_space_t::generic,
+              address_space_t DstAddressSpace = address_space_t::generic>
     __device__ void Run_optimized_src_address_calculation(const TData* p_src, TData* p_dst) const
     {
         using src_vector_t = typename vector_type<TData, SrcDataPerAccess>::MemoryType;
@@ -1322,7 +1318,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         // TODO:: stop using this hack, once TransformedTensorDescriptor::GetLinearDimensionMask()
         // is implemented
         constexpr auto src_linear_dim_mask = SrcLinearDimensionMask{};
-        constexpr auto src_nonlinear_dim_mask = SrcNonLinearDimensionMask{};
+        constexpr auto src_nonlinear_dim_mask =
+            SrcLinearDimensionMask::Transform(logical_not<index_t>{});

         static_assert(
             src_linear_dim_mask.At(VectorAccessDim) || long_vector_size == SrcDataPerAccess,
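With the *NonLinearDimensionMask template parameters removed, the non-linear mask is now derived on the fly by flipping the linear mask with the new logical_not functor. A standalone sketch of that derivation (C++17, using std::array as a stand-in for the library's Sequence type):

#include <array>
#include <cstddef>

template <typename T>
struct logical_not
{
    constexpr bool operator()(const T& x) const { return !x; }
};

// element-wise complement of a dimension mask, standing in for
// SrcLinearDimensionMask::Transform(logical_not<index_t>{})
template <std::size_t N>
constexpr std::array<bool, N> flip_mask(std::array<bool, N> mask)
{
    for(std::size_t i = 0; i < N; ++i)
    {
        mask[i] = logical_not<bool>{}(mask[i]);
    }
    return mask;
}

int main()
{
    constexpr std::array<bool, 4> linear{{true, false, true, false}};
    constexpr auto nonlinear = flip_mask(linear); // {false, true, false, true}
    static_assert(nonlinear[1] && !nonlinear[0], "non-linear dimensions are the complement of linear ones");
    return 0;
}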
@@ -1392,9 +1389,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                     // Check src vector's padding situation, only check the first data in
                     // this src vector. It's user's responsiblity to make sure all data in
                     // the src vector has the same padding situation
-                    // TODO: not sure a dedicated IsAnyLevelIndexInPaddingArea() function is
-                    // neccessary
-                    if(!src_coord.IsAnyLevelIndexInPaddingArea())
+                    if(src_coord.IsUpperIndexMappedToValidOffset())
                     {
                         static_if<SrcAddressSpace == address_space_t::global>{}([&](auto) {
 #if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
@@ -1427,14 +1422,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                     const auto dst_coord = mDstSliceOrigin + (nonlinear_dim_data_steps +
                                                               linear_dim_data_steps + scalar_id);

                     // Check dst vector's padding situation, only check the first data in
                     // this dst vector. It's user's responsiblity to make sure all data in
                     // the dst vector has the same padding situation
-                    // TODO: not sure a dedicated IsAnyLevelIndexInPaddingArea() function is
-                    // neccessary
-#if 0 // tuning
-                    if(!dst_coord.IsAnyLevelIndexInPaddingArea())
-#endif
+                    if(dst_coord.IsUpperIndexMappedToValidOffset())
                     {
                         *reinterpret_cast<dst_vector_t*>(&p_dst[dst_coord.GetOffset()]) =
                             *reinterpret_cast<dst_vector_t*>(&p_long_vector[buffer_offset]);
@@ -1450,7 +1441,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
     // Will do padding check on src data: Read 0 if src data is in padding area.
     // Will do padding check on dst data: No write if dst data is in paddin area.
     // This version is optimized for address calculation of dst tensor
-    template <typename TData, address_space_t DstAddressSpace = address_space_t::generic>
+    template <typename TData,
+              address_space_t SrcAddressSpace = address_space_t::generic,
+              address_space_t DstAddressSpace = address_space_t::generic>
     __device__ void Run_optimized_dst_address_calculation(const TData* p_src, TData* p_dst) const
     {
         using src_vector_t = typename vector_type<TData, SrcDataPerAccess>::MemoryType;
@@ -1469,7 +1462,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         // TODO:: stop using this hack, once TransformedTensorDescriptor::GetLinearDimensionMask()
         // is implemented
         constexpr auto dst_linear_dim_mask = DstLinearDimensionMask{};
-        constexpr auto dst_nonlinear_dim_mask = DstNonLinearDimensionMask{};
+        constexpr auto dst_nonlinear_dim_mask =
+            DstLinearDimensionMask::Transform(logical_not<index_t>{});

         static_assert(
             dst_linear_dim_mask.At(VectorAccessDim) || long_vector_size == DstDataPerAccess,
@@ -1535,9 +1529,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                     // Check src vector's padding situation, only check the first data in
                     // this src vector. It's user's responsiblity to make sure all data in
                     // the src vector has the same padding situation
-                    // TODO: not sure a dedicated IsAnyLevelIndexInPaddingArea() function is
-                    // neccessary
-                    if(!src_coord.IsAnyLevelIndexInPaddingArea())
+                    if(src_coord.IsUpperIndexMappedToValidOffset())
                     {
                         *reinterpret_cast<src_vector_t*>(&p_long_vector[buffer_offset]) =
                             *reinterpret_cast<const src_vector_t*>(&p_src[src_coord.GetOffset()]);
@@ -1561,14 +1553,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                     const index_t dst_linear_offset =
                         dst_coord.GetOffset() - dst_nonlinear_coord.GetOffset();

                     // Check dst vector's padding situation, only check the first data in
                     // this dst vector. It's user's responsiblity to make sure all data in
                     // the dst vector has the same padding situation
-                    // TODO: not sure a dedicated IsAnyLevelIndexInPaddingArea() function is
-                    // neccessary
-#if 0 // tuning
-                    if(!dst_coord.IsAnyLevelIndexInPaddingArea())
-#endif
+                    if(dst_coord.IsUpperIndexMappedToValidOffset())
                     {
                         static_if<DstAddressSpace == address_space_t::global>{}([&](auto) {
 #if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
...
@@ -110,8 +110,7 @@ struct ArrayElementPicker
     __host__ __device__ explicit constexpr ArrayElementPicker(Arr& array) : mArray{array}
     {
-        constexpr index_t imax =
-            accumulate_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});
+        constexpr index_t imax = reduce_on_sequence(Picks{}, math::maxer<index_t>{}, Number<0>{});

         static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
     }
...
@@ -25,6 +25,12 @@ struct swallow
     }
 };

+template <typename T>
+struct logical_not
+{
+    constexpr bool operator()(const T& x) const { return !x; }
+};
+
 // Emulate if constexpr
 template <bool>
 struct static_if;
...
@@ -764,12 +764,12 @@ __host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
 #endif

 template <typename Seq, typename Reduce>
-struct lambda_accumulate_on_sequence
+struct lambda_reduce_on_sequence
 {
     const Reduce& f;
     index_t& result;

-    __host__ __device__ constexpr lambda_accumulate_on_sequence(const Reduce& f_, index_t& result_)
+    __host__ __device__ constexpr lambda_reduce_on_sequence(const Reduce& f_, index_t& result_)
         : f(f_), result(result_)
     {
     }
@@ -783,14 +783,42 @@ struct lambda_reduce_on_sequence
 template <typename Seq, typename Reduce, index_t Init>
 __host__ __device__ constexpr index_t
-accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
+reduce_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
 {
     index_t result = Init;

-    static_for<0, Seq::mSize, 1>{}(lambda_accumulate_on_sequence<Seq, Reduce>(f, result));
+    static_for<0, Seq::Size(), 1>{}(lambda_reduce_on_sequence<Seq, Reduce>(f, result));

     return result;
 }

+// TODO: a generic any_of for any container
+template <typename Seq, typename F>
+__host__ __device__ constexpr bool sequence_any_of(Seq, F f /*initial_value*/)
+{
+    bool flag = false;
+
+    for(index_t i = 0; i < Seq::Size(); ++i)
+    {
+        flag = flag || f(Seq::At(i));
+    }
+
+    return flag;
+}
+
+// TODO: a generic all_of for any container
+template <typename Seq, typename F>
+__host__ __device__ constexpr bool sequence_all_of(Seq, F f /*initial_value*/)
+{
+    bool flag = true;
+
+    for(index_t i = 0; i < Seq::Size(); ++i)
+    {
+        flag = flag && f(Seq::At(i));
+    }
+
+    return flag;
+}
+
 } // namespace ck
 #endif
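sequence_any_of and sequence_all_of are simple fold-style predicates over a compile-time Sequence. A standalone analogue over a plain array (illustrative only; needs C++14 for the loop inside a constexpr function):

#include <cstddef>

struct is_zero
{
    constexpr bool operator()(int x) const { return x == 0; }
};

template <std::size_t N, typename F>
constexpr bool all_of(const int (&a)[N], F f)
{
    bool flag = true;
    for(std::size_t i = 0; i < N; ++i)
    {
        flag = flag && f(a[i]);
    }
    return flag;
}

template <std::size_t N, typename F>
constexpr bool any_of(const int (&a)[N], F f)
{
    bool flag = false;
    for(std::size_t i = 0; i < N; ++i)
    {
        flag = flag || f(a[i]);
    }
    return flag;
}

// the disabled branch in Pad::IsUpperIndexMappedToValidLowerIndex would use this kind of
// test to detect the "no padding anywhere" fast path
constexpr int left_pads[] = {0, 0, 1, 1};
static_assert(!all_of(left_pads, is_zero{}), "this tensor is padded somewhere");
static_assert(any_of(left_pads, is_zero{}), "but not in every dimension");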