"server/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "09bb2e30f69489b2bd5138fa81d9dbb54c1d2f19"
Commit eb8a1bf9 authored by Jehandad Khan's avatar Jehandad Khan
Browse files

fixed out tensor order for wrw

parent 5ebe74e6
...@@ -388,6 +388,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -388,6 +388,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
// output merged global tensor descriptor, for calculating origin of thread tensor // output merged global tensor descriptor, for calculating origin of thread tensor
// in global memory // in global memory
// JD: Even though we changed the layout of the output for wrw, we don't need to change the following unfold to merge, because the unfolded dimension is already contiguous
constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5), out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
Sequence<3>{}, Sequence<3>{},
...@@ -411,7 +412,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -411,7 +412,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(), out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
arithmetic_sequence_gen<0, 8, 1>::type{}, arithmetic_sequence_gen<0, 8, 1>::type{},
Number<1>{}); Number<1>{});
#elif 1 #elif 0
p_out_global[0] = p_out_thread[0]; p_out_global[0] = p_out_thread[0];
#endif #endif
} }
......
...@@ -62,7 +62,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -62,7 +62,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t B = (N * Ho * Wo) / (N1 * N2); constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
#if 1 #if 1
// JD: New params for wrw // JD: New params for wrw for debugging the out ptr seg fault
// each thread hold 64 data // each thread hold 64 data
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -125,8 +125,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -125,8 +125,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadA = 1;
constexpr index_t GemmDataPerReadB = 4; constexpr index_t GemmDataPerReadB = 1;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>; using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
...@@ -135,7 +135,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -135,7 +135,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1; constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
...@@ -143,7 +143,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -143,7 +143,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0 #elif 0
// each thread hold 32 data // each thread hold 32 data
...@@ -202,7 +202,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -202,7 +202,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
decltype(in_nchw_desc.ReorderGivenNew2Old(Sequence<1, 0, 2, 3>{})), decltype(in_nchw_desc.ReorderGivenNew2Old(Sequence<1, 0, 2, 3>{})),
decltype(out_nkhw_desc.ReorderGivenNew2Old(Sequence<1, 0, 2, 3>{})), decltype(out_nkhw_desc.ReorderGivenNew2Old(Sequence<1, 0, 2, 3>{})),
// pass in the output instead of the weight, also reordered to knhw // pass in the output instead of the weight, also reordered to knhw
decltype(wei_kcyx_desc), decltype(wei_kcyx_desc.ReorderGivenNew2Old(Sequence<1, 0, 2, 3>{})),
// the output would be the weights, which would not be reordered // the output would be the weights, which would not be reordered
// as discussed in the morning for wrw strides and dilation switch positions // as discussed in the morning for wrw strides and dilation switch positions
ConvDilations, // wrw: becomes stride ConvDilations, // wrw: becomes stride
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment