Commit 4a1e97cf authored by Chao Liu

tweak

parent c2d24669
@@ -155,7 +155,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                       "GemmDataPerReadB alignment requirement is not satisfied");

-#if 1 // debug
+#if 1
         // input blockwise copy
         // slice a merged tensor, reorder and copy to a normal tensor
         // this copy operator already has blockwise offset built-in
@@ -198,7 +198,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
             Sequence<EPerBlock, KPerBlock>{},
             Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

-#if 1 // debug
+#if 1
         // operator for blockwise copy of weight into LDS
         // slice a tensor, and copy it into another tensor
         // this copy operator already have blockwise offset built-in
@@ -324,10 +324,12 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer

 #if 1
             blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
+            // blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{},
+            // True);
             p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
 #else
-            blockwise_in_copy.MoveSrcSlicingWindow({EPerBlock, 0, 0, 0}, true);
-            blockwise_wei_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true);
+            blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
+            blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
 #endif

             __syncthreads();
@@ -348,16 +350,17 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer

         // LDS double buffer: tail
         {
-            // even iteration
             Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
             Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];

+            // even iteration
 #if 1
             blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
+            // blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
             p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
 #else
-            blockwise_in_copy.MoveSrcSlicingWindow({EPerBlock, 0, 0, 0}, true);
-            blockwise_wei_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true);
+            blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
+            blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
 #endif

             __syncthreads();
@@ -431,7 +434,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
                 out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
                     k_thread_data_on_global, 0, b_thread_data_on_global, 0);

-#if 1 // debug
+#if 1
             threadwise_generic_tensor_slice_copy_v1(
                 out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
                 p_out_thread,
...
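A note on the `#else` branches above: `MoveSrcSlicingWindow` changes from taking a runtime `Array<index_t, nDim>` plus a `bool` to taking a `Sequence<...>` type plus an `integral_constant<bool, ...>` tag, so both the step sizes and the direction become compile-time values the coordinate arithmetic can fold away. A minimal sketch of the shapes involved, using `std::integral_constant` as a stand-in for ck's own type and assuming `True` is a value of `integral_constant<bool, true>`:

    #include <type_traits>

    // Toy stand-ins for ck's compile-time types, for illustration only.
    template <int... Is>
    struct Sequence
    {
    };

    constexpr std::integral_constant<bool, true> True{}; // assumed definition of the True tag

    // Old shape: runtime values, opaque to the optimizer.
    //   void MoveSrcSlicingWindow(Array<index_t, nDim> step_sizes, bool positive_direction);
    // New shape: step sizes and direction are encoded in the type system.
    //   template <class T, bool PositiveDirection>
    //   void MoveSrcSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>);
    // Call site, as in the hunks above:
    //   blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);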
@@ -125,12 +125,15 @@ struct MergedTensorCoordinate
     __host__ __device__ constexpr index_t GetOffset() const { return mOffset; }

-    // step_size should be known at compile time
-    template <class IDim, bool PositiveDirection>
+    template <class IDim, class T, bool PositiveDirection>
     __host__ __device__ void
-    MoveOnDimension(IDim, index_t step_size, integral_constant<bool, PositiveDirection>)
+    MoveOnDimension(IDim idim_, T step_size, integral_constant<bool, PositiveDirection>)
     {
-        constexpr auto idim = IDim{};
+        constexpr auto idim = idim_;
+
+        // if step_size is known at compile time
+        static_if<is_static<T>::value>{}(
+            [&](auto) { static_if<T{} == 0>{}([&](auto) { return; }); });

         // update original index
         static_if<tensor_desc_type::ContainMultipleOriginalDimensions(idim)>{}([&](auto) {
...
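The added guard in `MoveOnDimension` is meant to skip the move when the step size is a compile-time constant equal to zero. As written, though, the `return;` exits only the innermost lambda, not `MoveOnDimension` itself, so the update code below it still runs. A sketch of the intended early-out, using C++17 `if constexpr` rather than ck's lambda-based `static_if` (the trait here mirrors the `is_static` added elsewhere in this commit):

    #include <type_traits>

    // Minimal is_static stand-in: true only for types that encode a value.
    template <class>
    struct is_static : std::false_type
    {
    };

    template <class T, T X>
    struct is_static<std::integral_constant<T, X>> : std::true_type
    {
    };

    // If the step is statically known to be zero, skip the index update entirely;
    // `return` here really leaves the function, unlike a `return` inside a lambda.
    template <class T>
    void move_on_dimension(int& index, T step_size)
    {
        if constexpr (is_static<T>::value)
        {
            if constexpr (T{} == 0)
                return;
        }
        index += step_size; // runtime or non-zero static step: do the move
    }

    // move_on_dimension(i, 5);                                // runtime step
    // move_on_dimension(i, std::integral_constant<int, 0>{}); // folds to a no-op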
@@ -446,14 +446,18 @@ struct BlockwiseGenericTensorSliceCopy_v2
         mThreadwiseStore.Run(p_buffer, p_dst);
     }

-    __device__ void MoveSrcSlicingWindow(Array<index_t, nDim> step_sizes, bool positive_direction)
+    template <class T, bool PositiveDirection>
+    __device__ void MoveSrcSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
     {
-        mThreadwiseLoad.MoveSrcSlicingWindow(step_sizes, positive_direction);
+        mThreadwiseLoad.MoveSrcSlicingWindow(step_sizes,
+                                             integral_constant<bool, PositiveDirection>{});
     }

-    __device__ void MoveDstSlicingWindow(Array<index_t, nDim> step_sizes, bool positive_direction)
+    template <class T, bool PositiveDirection>
+    __device__ void MoveDstSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
     {
-        mThreadwiseStore.MoveDstSlicingWindow(step_sizes, positive_direction);
+        mThreadwiseStore.MoveDstSlicingWindow(step_sizes,
+                                              integral_constant<bool, PositiveDirection>{});
     }

     private:
...
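The blockwise copy only forwards the move to its threadwise load/store engines; the new template signature re-materializes the direction tag when forwarding, so the callee also receives the direction as a type. Note that the destination window belongs to the store engine. A toy sketch of the pattern (names are illustrative, not ck's actual classes):

    #include <type_traits>

    // The threadwise engine actually moves its slice origin.
    struct ThreadwiseCopy
    {
        int origin = 0;

        template <class T, bool PositiveDirection>
        void MoveSrcSlicingWindow(T step, std::integral_constant<bool, PositiveDirection>)
        {
            origin += PositiveDirection ? int(step) : -int(step);
        }
    };

    // The blockwise wrapper owns the engines and forwards, rebuilding the tag
    // exactly as the diff above does.
    struct BlockwiseCopy
    {
        ThreadwiseCopy mThreadwiseLoad;

        template <class T, bool PositiveDirection>
        void MoveSrcSlicingWindow(T step, std::integral_constant<bool, PositiveDirection>)
        {
            mThreadwiseLoad.MoveSrcSlicingWindow(step,
                                                 std::integral_constant<bool, PositiveDirection>{});
        }
    };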
@@ -216,28 +216,20 @@ struct ThreadwiseGenericTensorSliceCopy_v2
         });
     }

-    __device__ void MoveSrcSlicingWindow(Array<index_t, nDim> step_sizes, bool positive_direction)
+    template <class T, bool PositiveDirection>
+    __device__ void MoveSrcSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
     {
-        if(positive_direction)
-        {
-            mSrcSliceOrigin += step_sizes;
-        }
-        else
-        {
-            mSrcSliceOrigin -= step_sizes;
-        }
+        static_if<PositiveDirection>{}([&](auto) {
+            mSrcSliceOrigin += step_sizes;
+        }).Else([&](auto) { mSrcSliceOrigin -= step_sizes; });
     }

-    __device__ void MoveDstSlicingWindow(Array<index_t, nDim> step_sizes, bool positive_direction)
+    template <class T, bool PositiveDirection>
+    __device__ void MoveDstSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
     {
-        if(positive_direction)
-        {
-            mDstSliceOrigin += step_sizes;
-        }
-        else
-        {
-            mDstSliceOrigin -= step_sizes;
-        }
+        static_if<PositiveDirection>{}([&](auto) { mDstSliceOrigin += step_sizes; }).Else([&](auto) {
+            mDstSliceOrigin -= step_sizes;
+        });
     }

     // private:
...
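`static_if` replaces the runtime `if (positive_direction)` with compile-time dispatch: only the branch selected by `PositiveDirection` is instantiated, which matters once `step_sizes` is a type whose `+=` or `-=` may only be well-formed in one direction. A minimal stand-in showing how such a construct can be built (ck's real `static_if` passes an identity functor to the lambda; a dummy `int` plays that role here):

    #include <type_traits>

    // Minimal static_if stand-in: the predicate picks which generic lambda is
    // invoked; the untaken lambda is accepted but its body is never instantiated.
    template <bool Predicate>
    struct static_if;

    template <>
    struct static_if<true>
    {
        template <class F>
        const static_if& operator()(F f) const
        {
            f(0); // run the then-branch
            return *this;
        }
        template <class F>
        void Else(F) const {} // else-branch discarded
    };

    template <>
    struct static_if<false>
    {
        template <class F>
        const static_if& operator()(F) const { return *this; } // then-branch discarded
        template <class F>
        void Else(F f) const { f(0); } // run the else-branch
    };

    // Usage in the shape of the new MoveSrcSlicingWindow:
    template <class T, bool PositiveDirection>
    void move_src(int& origin, T step, std::integral_constant<bool, PositiveDirection>)
    {
        static_if<PositiveDirection>{}([&](auto) { origin += int(step); })
            .Else([&](auto) { origin -= int(step); });
    }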
@@ -8,6 +8,21 @@
 namespace ck {

+template <class>
+struct is_static : integral_constant<bool, false>
+{
+};
+
+template <class T, T X>
+struct is_static<integral_constant<T, X>> : integral_constant<bool, true>
+{
+};
+
+template <index_t... Is>
+struct is_static<Sequence<Is...>> : integral_constant<bool, true>
+{
+};
+
 // RemainLengths: Sequence<...>
 template <class RemainLengths>
 struct static_ford_impl
...
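`is_static` is a specialization-based trait: the primary template answers false for arbitrary (runtime) types, and the two specializations opt in `integral_constant` and `Sequence`, whose values live entirely in the type. A standalone sketch with `std::` equivalents (the real trait sits in namespace ck and uses ck's own `integral_constant`, `Sequence`, and `index_t`):

    #include <type_traits>

    // Toy Sequence stand-in (ck's takes index_t non-type parameters).
    template <int... Is>
    struct Sequence
    {
    };

    // Primary template: by default a type carries no compile-time value.
    template <class>
    struct is_static : std::integral_constant<bool, false>
    {
    };

    // integral_constant encodes its value in the type, so it is static.
    template <class T, T X>
    struct is_static<std::integral_constant<T, X>> : std::integral_constant<bool, true>
    {
    };

    // A Sequence of integers is likewise fully known at compile time.
    template <int... Is>
    struct is_static<Sequence<Is...>> : std::integral_constant<bool, true>
    {
    };

    static_assert(!is_static<int>::value, "plain int is a runtime value");
    static_assert(is_static<std::integral_constant<int, 4>>::value, "");
    static_assert(is_static<Sequence<8, 0, 0, 0>>::value, "");

This trait is what lets `MoveOnDimension` above distinguish a runtime `index_t` step from a compile-time constant one.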
@@ -379,7 +379,7 @@ int main(int argc, char* argv[])
 #elif 0
     device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(
         in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
-#elif 0
+#elif 1
     device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
                                                          in_nchw,
                                                          wei_kcyx_desc,
...