Commit 8e51f990 authored by Jing Zhang
Browse files

opt

parent bb37eb69
...@@ -493,13 +493,13 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result) ...@@ -493,13 +493,13 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
constexpr index_t HStride = 1; constexpr index_t HStride = 2;
constexpr index_t WStride = 1; constexpr index_t WStride = 2;
constexpr index_t HDilation = 1; constexpr index_t HDilation = 1;
constexpr index_t WDilation = 1; constexpr index_t WDilation = 1;
constexpr index_t Direction = 1; // 1: Forward; 2:Backward constexpr index_t Direction = 2; // 1: Forward; 0:Backward
#if 0 #if 0
constexpr index_t N = 32; constexpr index_t N = 32;
constexpr index_t C = 128; constexpr index_t C = 128;
...@@ -551,8 +551,8 @@ int main(int argc, char* argv[]) ...@@ -551,8 +551,8 @@ int main(int argc, char* argv[])
// 1x1 filter, 28x28 image // 1x1 filter, 28x28 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 128;
constexpr index_t HI = 7; constexpr index_t HI = 13;
constexpr index_t WI = 7; constexpr index_t WI = 13;
constexpr index_t K = 128; constexpr index_t K = 128;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
......
...@@ -7,8 +7,7 @@ ...@@ -7,8 +7,7 @@
#include "blockwise_gemm.hip.hpp" #include "blockwise_gemm.hip.hpp"
#include "threadwise_generic_tensor_slice_op.hip.hpp" #include "threadwise_generic_tensor_slice_op.hip.hpp"
#define FORW 1 #define FORW 0
// define B = merge(N0, Ho, Wo) // define B = merge(N0, Ho, Wo)
template <index_t GridSize, template <index_t GridSize,
index_t BlockSize, index_t BlockSize,
...@@ -81,7 +80,6 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -81,7 +80,6 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr auto in_n_c_h_w_global_desc = OutGlobalDesc{}; constexpr auto in_n_c_h_w_global_desc = OutGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = InGlobalDesc{}; constexpr auto out_n_k_h_w_global_desc = InGlobalDesc{};
#endif #endif
// to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3) out = ek // to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3) out = ek
constexpr auto wei_k_c_1_1_global_desc = WeiGlobalDesc{}; constexpr auto wei_k_c_1_1_global_desc = WeiGlobalDesc{};
...@@ -138,6 +136,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -138,6 +136,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
#if FORW #if FORW
constexpr auto in_lengths_new = Sequence<N0, N1, N2, Ho, Wo>{}; constexpr auto in_lengths_new = Sequence<N0, N1, N2, Ho, Wo>{};
constexpr auto in_strides_new = constexpr auto in_strides_new =
Sequence<in_n0_n1_n2_h_w_global_desc.GetStride(I0), Sequence<in_n0_n1_n2_h_w_global_desc.GetStride(I0),
in_n0_n1_n2_h_w_global_desc.GetStride(I1), in_n0_n1_n2_h_w_global_desc.GetStride(I1),
...@@ -152,7 +151,6 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -152,7 +151,6 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr auto in_n0_n1_n2_h_w_new_global_desc = in_n0_n1_n2_h_w_global_desc; constexpr auto in_n0_n1_n2_h_w_new_global_desc = in_n0_n1_n2_h_w_global_desc;
#endif #endif
// batch descriptor for device memory // batch descriptor for device memory
// to-do: add dilation: keep lengths, modify strides // to-do: add dilation: keep lengths, modify strides
constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{}) constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{})
...@@ -349,7 +347,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -349,7 +347,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()]; Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True); blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
#if 0 #if 1
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0); p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
#else #else
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True); blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
...@@ -380,7 +378,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -380,7 +378,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// even iteration // even iteration
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True); blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
#if 0 #if 1
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0); p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
#else #else
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True); blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment