Commit 8e51f990 authored by Jing Zhang's avatar Jing Zhang
Browse files

opt

parent bb37eb69
......@@ -493,13 +493,13 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int main(int argc, char* argv[])
{
constexpr index_t HStride = 1;
constexpr index_t WStride = 1;
constexpr index_t HStride = 2;
constexpr index_t WStride = 2;
constexpr index_t HDilation = 1;
constexpr index_t WDilation = 1;
constexpr index_t Direction = 1; // 1: Forward; 2:Backward
constexpr index_t Direction = 2; // 1: Forward; 0:Backward
#if 0
constexpr index_t N = 32;
constexpr index_t C = 128;
......@@ -551,8 +551,8 @@ int main(int argc, char* argv[])
// 1x1 filter, 28x28 image
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t HI = 7;
constexpr index_t WI = 7;
constexpr index_t HI = 13;
constexpr index_t WI = 13;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
......
......@@ -7,8 +7,7 @@
#include "blockwise_gemm.hip.hpp"
#include "threadwise_generic_tensor_slice_op.hip.hpp"
#define FORW 1
#define FORW 0
// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
......@@ -81,7 +80,6 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr auto in_n_c_h_w_global_desc = OutGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = InGlobalDesc{};
#endif
// to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3 out = ek
constexpr auto wei_k_c_1_1_global_desc = WeiGlobalDesc{};
......@@ -138,6 +136,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
#if FORW
constexpr auto in_lengths_new = Sequence<N0, N1, N2, Ho, Wo>{};
constexpr auto in_strides_new =
Sequence<in_n0_n1_n2_h_w_global_desc.GetStride(I0),
in_n0_n1_n2_h_w_global_desc.GetStride(I1),
......@@ -152,7 +151,6 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr auto in_n0_n1_n2_h_w_new_global_desc = in_n0_n1_n2_h_w_global_desc;
#endif
// batch descritpor for device memory
// to-do: add dilation: keep lengths, modify strides
constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{})
......@@ -349,7 +347,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
#if 0
#if 1
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
#else
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
......@@ -380,7 +378,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// even iteration
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
#if 0
#if 1
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
#else
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment