Commit 0b10c0bb authored by Jing Zhang

add backward convolution

parent e4d2fc6f
@@ -77,14 +77,14 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
     constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
-    using WeiBlockCopySubLengths_E_K     = Sequence<4, 1>;
-    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
+    using WeiBlockCopySubLengths_E_K     = Sequence<1, 4>;
+    using WeiBlockCopyClusterLengths_E_K = Sequence<8, 32>;
     using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
     using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
     using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]
-    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
-    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
+    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 1;
+    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 4;
 #endif
     constexpr index_t GridSize =
...
@@ -110,7 +110,7 @@ template <class TIn,
           class UpperPads,
           class Strides,
           class Dilations>
-void host_direct_convolution(const Tensor<TIn>& in_nchw,
+void host_direct_convolution_forw(const Tensor<TIn>& in_nchw,
                              const Tensor<TWei>& wei_kcyx,
                              Tensor<TOut>& out_nkhw,
                              LowerPads,
@@ -160,6 +160,66 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
     f_par(std::thread::hardware_concurrency());
 }
+
+template <class TIn,
+          class TWei,
+          class TOut,
+          class LowerPads,
+          class UpperPads,
+          class Strides,
+          class Dilations>
+void host_direct_convolution_back(Tensor<TOut>& in_nchw,
+                                  const Tensor<TWei>& wei_kcyx,
+                                  const Tensor<TIn>& out_nkhw,
+                                  LowerPads,
+                                  UpperPads,
+                                  Strides,
+                                  Dilations)
+{
+    index_t h_pad_low = LowerPads{}.Get(Number<0>{});
+    index_t w_pad_low = LowerPads{}.Get(Number<1>{});
+
+    index_t h_pad_up = UpperPads{}.Get(Number<0>{});
+    index_t w_pad_up = UpperPads{}.Get(Number<1>{});
+
+    index_t stride_h = Strides{}.Get(Number<0>{});
+    index_t stride_w = Strides{}.Get(Number<1>{});
+
+    index_t dilation_h = Dilations{}.Get(Number<0>{});
+    index_t dilation_w = Dilations{}.Get(Number<1>{});
+
+    // loop over n, c, hi, wi: each input element gathers from the output
+    auto f = [&](auto n, auto c, auto hi, auto wi) {
+        double v = 0;
+        // loop over k, y, x, inverting the forward mapping
+        //   hi = ho * stride_h + y * dilation_h - h_pad_low
+        for(int k = 0; k < wei_kcyx.mDesc.GetLengths()[0]; ++k)
+        {
+            for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y)
+            {
+                int h_tmp = int(hi) + int(h_pad_low) - y * int(dilation_h);
+
+                // only a non-negative multiple of stride_h maps back to a valid ho
+                if(h_tmp < 0 || h_tmp % stride_h != 0)
+                    continue;
+
+                int ho = h_tmp / stride_h;
+
+                for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x)
+                {
+                    int w_tmp = int(wi) + int(w_pad_low) - x * int(dilation_w);
+
+                    if(w_tmp < 0 || w_tmp % stride_w != 0)
+                        continue;
+
+                    int wo = w_tmp / stride_w;
+
+                    if(ho < out_nkhw.mDesc.GetLengths()[2] && wo < out_nkhw.mDesc.GetLengths()[3])
+                    {
+                        v += double(out_nkhw(n, k, ho, wo)) * double(wei_kcyx(k, c, y, x));
+                    }
+                }
+            }
+        }
+        in_nchw(n, c, hi, wi) = v;
+    };
+
+    auto f_par = make_ParallelTensorFunctor(f,
+                                            in_nchw.mDesc.GetLengths()[0],
+                                            in_nchw.mDesc.GetLengths()[1],
+                                            in_nchw.mDesc.GetLengths()[2],
+                                            in_nchw.mDesc.GetLengths()[3]);
+
+    f_par(std::thread::hardware_concurrency());
+}
 template <class TIn, class TWei, class TOut, class LowerPads, class UpperPads>
 void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
                                    const Tensor<TWei>& wei_kcyx,
@@ -422,16 +482,18 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
 int main(int argc, char* argv[])
 {
-    constexpr index_t U  = 1;
-    constexpr index_t V  = 1;
-    constexpr index_t Dh = 2;
-    constexpr index_t Dw = 2;
-#if 0
+    constexpr index_t HStride   = 2;
+    constexpr index_t WStride   = 2;
+    constexpr index_t HDilation = 1;
+    constexpr index_t WDilation = 1;
+
+    constexpr index_t Direction = 2; // 1: Forward; 2: Backward
+
+#if 1
     constexpr index_t N  = 8;
-    constexpr index_t C  = 16;
-    constexpr index_t HI = 20;
-    constexpr index_t WI = 20;
+    constexpr index_t C  = 128;
+    constexpr index_t HI = 16;
+    constexpr index_t WI = 16;
     constexpr index_t K  = 128;
     constexpr index_t Y  = 1;
     constexpr index_t X  = 1;
@@ -462,7 +524,7 @@ int main(int argc, char* argv[])
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
-#elif 1
+#elif 0
     // 3x3 filter, 28x28 image
     constexpr index_t N = 128;
     constexpr index_t C = 256;
@@ -474,7 +536,7 @@ int main(int argc, char* argv[])
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
-#elif 0
+#elif 1
    // 1x1 filter, 28x28 image
    constexpr index_t N = 128;
    constexpr index_t C = 512;
@@ -599,14 +661,16 @@ int main(int argc, char* argv[])
     auto lower_pads = Sequence<HPad, WPad>{};
     auto upper_pads = Sequence<HPad, WPad>{};
-    auto strides   = Sequence<U, V>{};
-    auto dilations = Sequence<Dh, Dw>{};
+    auto strides   = Sequence<HStride, WStride>{};
+    auto dilations = Sequence<HDilation, WDilation>{};
     auto in_nchw_desc  = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
     auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
     auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
         in_nchw_desc, wei_kcyx_desc, strides, dilations);
+    auto wei_ckyx_back_desc = wei_kcyx_desc.ReorderGivenNew2Old(Sequence<1, 0, 2, 3>{});
     ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
     ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
     ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
@@ -614,9 +678,10 @@ int main(int argc, char* argv[])
     using in_data_t  = float;
     using out_data_t = float;
     Tensor<in_data_t> in_nchw(make_TensorDescriptor(in_nchw_desc));
-    Tensor<in_data_t> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
-    Tensor<out_data_t> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
+    Tensor<in_data_t> out_nkhw(make_TensorDescriptor(out_nkhw_desc));
+    Tensor<in_data_t> in_nchw_device(make_TensorDescriptor(in_nchw_desc));
     Tensor<out_data_t> out_nkhw_device(make_TensorDescriptor(out_nkhw_desc));
+    Tensor<in_data_t> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
     std::size_t num_thread = std::thread::hardware_concurrency();
@@ -642,6 +707,7 @@ int main(int argc, char* argv[])
     wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
 #elif 1
     in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+    out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
     wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
 #elif 0
     in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
@@ -673,15 +739,16 @@ int main(int argc, char* argv[])
 #elif 1
     device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
 #endif
-    (in_nchw_desc,
-     in_nchw,
-     wei_kcyx_desc,
+    (out_nkhw_desc,
+     out_nkhw,
+     wei_ckyx_back_desc,
      wei_kcyx,
-     out_nkhw_desc,
+     in_nchw_desc,
      strides,
      dilations,
-     out_nkhw_device,
-     nrepeat);
+     in_nchw_device,
+     nrepeat
+    );
 #elif 1
     device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(in_nchw_desc,
@@ -704,11 +771,18 @@ int main(int argc, char* argv[])
     }
     else
 #endif
+    if(Direction == 1)
+    {
+        host_direct_convolution_forw(
+            in_nchw, wei_kcyx, out_nkhw, lower_pads, upper_pads, strides, dilations);
+        check_error(out_nkhw, out_nkhw_device);
+    }
+    else
     {
-        host_direct_convolution(
-            in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads, strides, dilations);
+        host_direct_convolution_back(
+            in_nchw, wei_kcyx, out_nkhw, lower_pads, upper_pads, strides, dilations);
+        check_error(in_nchw, in_nchw_device);
     }
-    check_error(out_nkhw_host, out_nkhw_device);
 #if 0
     LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
...
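The host_direct_convolution_back reference added above gathers into each input element by inverting the forward index mapping hi = ho * stride_h + y * dilation_h - h_pad_low used by host_direct_convolution_forw. A minimal standalone sketch of that inversion for one spatial dimension (names, values, and the small driver are illustrative only, not part of this commit):

```cpp
#include <cstdio>
#include <utility>
#include <vector>

// Enumerate every (y, ho) pair of the forward mapping that touches input row hi,
// i.e. every pair with hi = ho * stride + y * dilation - pad_low.
std::vector<std::pair<int, int>> contributing_taps(
    int hi, int filter_len, int out_len, int stride, int dilation, int pad_low)
{
    std::vector<std::pair<int, int>> taps;
    for(int y = 0; y < filter_len; ++y)
    {
        int numerator = hi + pad_low - y * dilation;
        if(numerator < 0 || numerator % stride != 0)
            continue; // this filter row never lands on hi
        int ho = numerator / stride;
        if(ho < out_len)
            taps.emplace_back(y, ho);
    }
    return taps;
}

int main()
{
    // hi = 5 with a 3-tap filter, stride 2, dilation 1, no padding:
    // only (y = 1, ho = 2) satisfies 5 = ho * 2 + y * 1
    for(auto [y, ho] : contributing_taps(5, 3, 8, 2, 1, 0))
        std::printf("y = %d, ho = %d\n", y, ho);
}
```

With stride 1 every filter tap contributes; with stride 2 only taps whose numerator is even survive, which is exactly the divisibility check inside the host backward reference.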
@@ -112,48 +112,49 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
         // input tensor
         // tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
-        // constexpr auto in_n0_n1_n2_h_w_global_desc = in_n_c_h_w_global_desc.Slice(I2,
-        //                                              Number<Ho>{})
-        //.Slice(I3, Number<Wo>{})
-        //.Fold(I0, Number<N1>{}, Number<N2>{})
-        //.Extract(Sequence<0, 1, 2, 4, 5>{});
-        constexpr auto in_n0_n1_n2_h_w_global_desc =
-            in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{})
-                .Extract(Sequence<0, 1, 2, 4, 5>{});
-        constexpr auto in_lengths_new = Sequence<N0, N1, N2, Ho, Wo>{};
-        constexpr auto in_strides_new =
-            Sequence<in_n0_n1_n2_h_w_global_desc.GetStride(I0),
-                     in_n0_n1_n2_h_w_global_desc.GetStride(I1),
-                     in_n0_n1_n2_h_w_global_desc.GetStride(I2),
-                     in_n0_n1_n2_h_w_global_desc.GetStride(I3) * Strides{}.Get(I0),
-                     in_n0_n1_n2_h_w_global_desc.GetStride(I4) * Strides{}.Get(I1)>{};
-        constexpr auto in_n0_n1_n2_h_w_new_global_desc =
-            make_ConstantTensorDescriptor(in_lengths_new, in_strides_new);
+        constexpr auto in_n0_n1_n2_h_w_global_desc = in_n_c_h_w_global_desc.Slice(I2,
+                                                                                   Number<Ho>{})
+                                                         .Slice(I3, Number<Wo>{})
+                                                         .Fold(I0, Number<N1>{}, Number<N2>{})
+                                                         .Extract(Sequence<0, 1, 2, 4, 5>{});
+        //constexpr auto in_n0_n1_n2_h_w_global_desc =
+        //    in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{})
+        //        .Extract(Sequence<0, 1, 2, 4, 5>{});
+        //constexpr auto in_lengths_new = Sequence<N0, N1, N2, Ho, Wo>{};
+        //constexpr auto in_strides_new =
+        //    Sequence<in_n0_n1_n2_h_w_global_desc.GetStride(I0),
+        //             in_n0_n1_n2_h_w_global_desc.GetStride(I1),
+        //             in_n0_n1_n2_h_w_global_desc.GetStride(I2),
+        //             in_n0_n1_n2_h_w_global_desc.GetStride(I3),
+        //             in_n0_n1_n2_h_w_global_desc.GetStride(I4)>{};
+        //constexpr auto in_n0_n1_n2_h_w_new_global_desc =
+        //    make_ConstantTensorDescriptor(in_lengths_new, in_strides_new);
+        constexpr auto in_n0_n1_n2_h_w_new_global_desc = in_n0_n1_n2_h_w_global_desc;
         // batch descritpor for device memory
         // to-do: add dilation: keep lengths, modify strides
         constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{})
                                                   .Slice(I3, Number<X>{})
                                                   .Extract(Sequence<1, 2, 3>{});
-        constexpr auto in_win_lengths_new = Sequence<in_c_y_x_global_desc.GetLength(I0),
-                                                     in_c_y_x_global_desc.GetLength(I1),
-                                                     in_c_y_x_global_desc.GetLength(I2)>{};
-        constexpr auto in_win_strides_new =
-            Sequence<in_c_y_x_global_desc.GetStride(I0),
-                     in_c_y_x_global_desc.GetStride(I1) * Dilations{}.Get(I0),
-                     in_c_y_x_global_desc.GetStride(I2) * Dilations{}.Get(I1)>{};
-        constexpr auto in_c_y_x_new_global_desc =
-            make_ConstantTensorDescriptor(in_win_lengths_new, in_win_strides_new);
+        //constexpr auto in_win_lengths_new = Sequence<in_c_y_x_global_desc.GetLength(I0),
+        //                                             in_c_y_x_global_desc.GetLength(I1),
+        //                                             in_c_y_x_global_desc.GetLength(I2)>{};
+        //constexpr auto in_win_strides_new =
+        //    Sequence<in_c_y_x_global_desc.GetStride(I0),
+        //             in_c_y_x_global_desc.GetStride(I1),
+        //             in_c_y_x_global_desc.GetStride(I2)>{};
+        //constexpr auto in_c_y_x_new_global_desc =
+        //    make_ConstantTensorDescriptor(in_win_lengths_new, in_win_strides_new);
         // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
         constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
-            in_c_y_x_new_global_desc.Embed(in_n0_n1_n2_h_w_new_global_desc),
+            in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_new_global_desc),
             Sequence<0, 1, 2>{},
             Sequence<4>{},
             Sequence<3, 6, 7>{},
@@ -190,8 +191,15 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
         // weight tensor
         // tensor descriptor in device memory, src of blockwise copy
+#if 0
         constexpr auto wei_e_k_global_desc =
             wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
+#else
+        constexpr auto wei_e_k_global_desc =
+            make_ConstantMergedTensorDescriptor(wei_k_c_y_x_global_desc,
+                                                Sequence<1, 2, 3>{},
+                                                Sequence<0>{});
+#endif
         // tensor descriptor in LDS, dst of blockwise copy
         // be careful of LDS alignment
@@ -356,7 +364,11 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
             Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
             blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
+#if 0
             p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
+#else
+            blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
+#endif
             __syncthreads();
@@ -383,7 +395,11 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
             // even iteration
             blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
+#if 0
             p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
+#else
+            blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
+#endif
             __syncthreads();
@@ -432,6 +448,34 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
             out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
                 .Fold(I0, Number<N1>{}, Number<N2>{});
+        constexpr auto out_lengths_new = Sequence<
+            out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I0),
+            out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I1),
+            out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I2),
+            out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I3),
+            out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I4),
+            out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I5),
+            out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I6) / Strides{}.Get(I0),
+            out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I7) / Strides{}.Get(I1)>{};
+
+        constexpr auto out_strides_new = Sequence<
+            out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I0),
+            out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I1),
+            out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I2),
+            out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I3),
+            out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I4),
+            out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I5),
+            out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I6) * Strides{}.Get(I0),
+            out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I7) * Strides{}.Get(I1)>{};
+
+        constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc =
+            make_ConstantTensorDescriptor(out_lengths_new, out_strides_new);
         // calculate origin of thread output tensor on global memory
         // blockwise GEMM c matrix starting index
         const auto c_thread_mtx_on_block =
@@ -446,7 +490,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
         // output merged global tensor descriptor, for calculating origin of thread tensor
         // in global memory
         constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
-            out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
+            out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc.Unfold(I3, I5),
             Sequence<3>{},
             Sequence<1>{},
             Sequence<0, 4, 5>{},
@@ -462,7 +506,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
             out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
             p_out_thread,
             {0, 0, 0, 0, 0, 0, 0, 0},
-            out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
+            out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc,
             p_out_thread_on_global,
             {0, 0, 0, 0, 0, 0, 0, 0},
             out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
...
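The out_lengths_new / out_strides_new pair added in the hunk above builds a strided view of the same buffer: the two spatial lengths are divided by Strides while the matching strides are multiplied by them. A small worked example, assuming a packed 16 x 16 H x W plane (W stride 1, H stride 16) and Strides = Sequence<2, 2>, matching the backward configuration in main():

    lengths: 16 x 16  ->  16/2 x 16/2 = 8 x 8
    strides: 16, 1    ->  16*2, 1*2   = 32, 2
    offset(h, w) = h*32 + w*2 = (2*h)*16 + (2*w)*1

so element (h, w) of the new descriptor addresses element (2h, 2w) of the original plane, presumably so that, in the backward launch where the real input tensor plays the role of the kernel's output, each write lands on a stride-aligned position.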