"docs/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "8c262c14b66395b91b7a52782b02aff8fe3fb20d"
Commit e9733a9f authored by Chao Liu
Browse files

experimenting TensorCoordinate and new merged tensor copy operator

parent b9663356
...@@ -299,7 +299,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw ...@@ -299,7 +299,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
blockwise_in_copy.Run(p_in_global, p_in_block); blockwise_in_copy.Run(p_in_global, p_in_block);
blockwise_wei_copy.Run(p_wei_global, p_wei_block); blockwise_wei_copy.Run(p_wei_global, p_wei_block);
#else #else
using InSrcMergedDimSubLengthsHack = Sequence<1, 1, 1, 1>; using InSrcMergedDimSubLengthsHack = Sequence<InBlockCopySubLengths_E_N1_B_N2{}[0],
1,
InBlockCopySubLengths_E_N1_B_N2{}[2],
1>;
using InDstMergedDimSubLengthsHack = Sequence<1, 1, 1, 1>; using InDstMergedDimSubLengthsHack = Sequence<1, 1, 1, 1>;
blockwise_in_copy.Run_hack(p_in_global, blockwise_in_copy.Run_hack(p_in_global,
p_in_block, p_in_block,
...@@ -388,6 +391,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw ...@@ -388,6 +391,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
arithmetic_sequence_gen<0, 8, 1>::type{}, arithmetic_sequence_gen<0, 8, 1>::type{},
Number<1>{}); Number<1>{});
#else #else
using OutSrcMergedDimSliceLengthsHack = Sequence<1, 1, 1, 1, 1, 1, 1, 1>;
using OutDstMergedDimSliceLengthsHack = Sequence<1, 1, 1, 1, 1, 1, 1, 1>;
ThreadwiseGenericTensorSliceCopy_v2< ThreadwiseGenericTensorSliceCopy_v2<
Float, Float,
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc), decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
...@@ -396,7 +403,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw ...@@ -396,7 +403,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
MergedTensorCoordinate<decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc)>, MergedTensorCoordinate<decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc)>,
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths())>( decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths())>(
{0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0}) {0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
.Run(p_out_thread, p_out_thread_on_global); .Run_hack(p_out_thread,
p_out_thread_on_global,
OutSrcMergedDimSliceLengthsHack{},
OutDstMergedDimSliceLengthsHack{});
#endif #endif
} }
} }
......
...@@ -237,8 +237,24 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -237,8 +237,24 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
for(index_t e_block_data_begin = 0; e_block_data_begin < E; e_block_data_begin += EPerBlock) for(index_t e_block_data_begin = 0; e_block_data_begin < E; e_block_data_begin += EPerBlock)
{ {
#if 0
blockwise_in_copy.Run(p_in_global, p_in_block); blockwise_in_copy.Run(p_in_global, p_in_block);
blockwise_wei_copy.Run(p_wei_global, p_wei_block); blockwise_wei_copy.Run(p_wei_global, p_wei_block);
#else
using InSrcMergedDimSubLengthsHack = InBlockCopySubLengths_E_B;
using InDstMergedDimSubLengthsHack = Sequence<1, 1>;
blockwise_in_copy.Run_hack(p_in_global,
p_in_block,
InSrcMergedDimSubLengthsHack{},
InDstMergedDimSubLengthsHack{});
using WeiSrcMergedDimSubLengthsHack = Sequence<1, 1>;
using WeiDstMergedDimSubLengthsHack = Sequence<1, 1>;
blockwise_wei_copy.Run_hack(p_wei_global,
p_wei_block,
WeiSrcMergedDimSubLengthsHack{},
WeiDstMergedDimSubLengthsHack{});
#endif
__syncthreads(); __syncthreads();
...@@ -272,36 +288,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -272,36 +288,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
const index_t b_thread_data_on_global = const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col; b_block_data_on_global + c_thread_mtx_on_block.col;
#if 0
// origin of dst in device memory
Float* p_out_thread_on_global = p_out_global +
out_k_b_global_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, b_thread_data_on_global);
// dst descriptor
constexpr auto out_k0_k1_b0_b1_global_desc =
out_k_b_global_desc.Fold(I1, Number<B1>{}).Fold(I0, Number<K1>{});
// src descriptor
constexpr auto out_k0_k1_b0_b1_thread_desc = make_ConstantTensorDescriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat, GemmNPerThreadSubC>{});
const auto threadwise_out_copy =
ThreadwiseGenericTensorSliceCopy_v2<Float,
decltype(out_k0_k1_b0_b1_thread_desc),
decltype(out_k0_k1_b0_b1_global_desc),
decltype(
out_k0_k1_b0_b1_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 4, 1>::type,
1,
1>({0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
b_thread_data_on_global / B1,
b_thread_data_on_global % B1});
threadwise_out_copy.Run(p_out_thread, p_out_thread_on_global);
#elif 1
// This is a hack, because slicing a merged dimension is not supported yet. // This is a hack, because slicing a merged dimension is not supported yet.
// This should be replaced with logic above, once slicing a merged dimension support // This should be replaced with logic above, once slicing a merged dimension support
// become available // become available
...@@ -316,35 +302,37 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -316,35 +302,37 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
constexpr auto out_k0_k1_b_thread_desc = make_ConstantTensorDescriptor_packed( constexpr auto out_k0_k1_b_thread_desc = make_ConstantTensorDescriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat * GemmNPerThreadSubC>{}); Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat * GemmNPerThreadSubC>{});
using OutThreadCopySliceLengths =
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2< auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2<
Float, Float,
#if 1 // debug
decltype(out_k0_k1_b_thread_desc), decltype(out_k0_k1_b_thread_desc),
decltype(out_k0_k1_b_global_desc), decltype(out_k0_k1_b_global_desc),
NormalTensorCoordinate<decltype(out_k0_k1_b_thread_desc)>, NormalTensorCoordinate<decltype(out_k0_k1_b_thread_desc)>,
MergedTensorCoordinate<decltype(out_k0_k1_b_global_desc)>, MergedTensorCoordinate<decltype(out_k0_k1_b_global_desc)>,
#else OutThreadCopySliceLengths>({0, 0, 0},
decltype(out_k0_k1_b_thread_desc), {k_thread_data_on_global / K1,
decltype( k_thread_data_on_global % K1,
make_ConstantTensorDescriptor_packed(out_k0_k1_b_global_desc.GetLengths())), b_thread_data_on_global});
NormalTensorCoordinate<decltype(out_k0_k1_b_thread_desc)>,
NormalTensorCoordinate<decltype(
make_ConstantTensorDescriptor_packed(out_k0_k1_b_global_desc.GetLengths()))>,
#endif
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>>(
{0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
b_thread_data_on_global});
for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat) for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
{ {
#if 0
threadwise_out_copy.Run(p_out_thread, p_out_global); threadwise_out_copy.Run(p_out_thread, p_out_global);
#else
using OutSrcMergedDimSubLengthsHack = Sequence<1, 1, 1>;
using OutDstMergedDimSubLengthsHack =
Sequence<1, 1, OutThreadCopySliceLengths{}[2]>;
threadwise_out_copy.Run_hack(p_out_thread,
p_out_global,
OutSrcMergedDimSubLengthsHack{},
OutDstMergedDimSubLengthsHack{});
#endif
threadwise_out_copy.MoveSrcSlicingWindow({0, 0, GemmNPerThreadSubC}, true); threadwise_out_copy.MoveSrcSlicingWindow({0, 0, GemmNPerThreadSubC}, true);
threadwise_out_copy.MoveDstSlicingWindow({0, 0, B1}, true); threadwise_out_copy.MoveDstSlicingWindow({0, 0, B1}, true);
} }
#endif
} }
} }
}; };
......
...@@ -532,7 +532,7 @@ int main(int argc, char* argv[]) ...@@ -532,7 +532,7 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw( device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
#elif 1 #elif 0
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment