Commit ffa7e4be authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent fc7a6c85
...@@ -81,16 +81,15 @@ map_convolution_into_gemm_v2(const WeiDesc& wei_k_c_y_x_global_desc, ...@@ -81,16 +81,15 @@ map_convolution_into_gemm_v2(const WeiDesc& wei_k_c_y_x_global_desc,
} }
template <index_t BlockSize> template <index_t BlockSize>
struct DummyDynamicTransform_v2 struct DummyDynamicTransform_v2_1
{ {
template <typename WeiDesc, typename InDesc, typename OutDesc, typename TransformInDesc> template <typename WeiDesc, typename InDesc, typename OutDesc>
__device__ void Run_1(index_t* const __restrict__ p_wei_global, __device__ void Run_1(index_t* const __restrict__ p_wei_global,
float* const __restrict__ p_in_global, float* const __restrict__ p_in_global,
float* const __restrict__ p_out_global, float* const __restrict__ p_out_global,
const WeiDesc wei_k_c_y_x_global_desc, const WeiDesc wei_k_c_y_x_global_desc,
const InDesc in_n_c_hi_wi_global_desc, const InDesc in_n_c_hi_wi_global_desc,
const OutDesc out_n_k_ho_wo_global_desc, const OutDesc out_n_k_ho_wo_global_desc,
const TransformInDesc /* in_gemmk_gemmn_global_desc */,
const Array<index_t, 2> conv_strides, const Array<index_t, 2> conv_strides,
const Array<index_t, 2> conv_dilations, const Array<index_t, 2> conv_dilations,
const Array<index_t, 2> in_left_pads, const Array<index_t, 2> in_left_pads,
...@@ -131,14 +130,13 @@ struct DummyDynamicTransform_v2 ...@@ -131,14 +130,13 @@ struct DummyDynamicTransform_v2
} }
} }
template <typename WeiDesc, typename InDesc, typename OutDesc, typename TransformInDesc> template <typename WeiDesc, typename InDesc, typename OutDesc>
__device__ void Run_2(index_t* const __restrict__ p_wei_global, __device__ void Run_2(index_t* const __restrict__ p_wei_global,
float* const __restrict__ p_in_global, float* const __restrict__ p_in_global,
float* const __restrict__ p_out_global, float* const __restrict__ p_out_global,
const WeiDesc wei_k_c_y_x_global_desc, const WeiDesc wei_k_c_y_x_global_desc,
const InDesc in_n_c_hi_wi_global_desc, const InDesc in_n_c_hi_wi_global_desc,
const OutDesc out_n_k_ho_wo_global_desc, const OutDesc out_n_k_ho_wo_global_desc,
const TransformInDesc /* in_gemmk_gemmn_global_desc */,
const Array<index_t, 2> conv_strides, const Array<index_t, 2> conv_strides,
const Array<index_t, 2> conv_dilations, const Array<index_t, 2> conv_dilations,
const Array<index_t, 2> in_left_pads, const Array<index_t, 2> in_left_pads,
...@@ -187,21 +185,21 @@ struct DummyDynamicTransform_v2 ...@@ -187,21 +185,21 @@ struct DummyDynamicTransform_v2
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
#else #else
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor_v2( const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor_v2(
transform_dynamic_tensor_descriptor_v2( transform_dynamic_tensor_descriptor_v2(
move(in_n_c_hi_wi_global_desc), move(in_n_c_hi_wi_global_desc),
make_tuple(DynamicPassThrough{N},
DynamicPassThrough{C},
DynamicLeftPad{Hi, InLeftPadH},
DynamicLeftPad{Wi, InLeftPadW}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})),
make_tuple(DynamicPassThrough{N}, make_tuple(DynamicPassThrough{N},
DynamicPassThrough{C}, DynamicPassThrough{C},
DynamicLeftPad{Hi, InLeftPadH}, DynamicRightPad{Hi + InLeftPadH, InRightPadH},
DynamicLeftPad{Wi, InLeftPadW}), DynamicRightPad{Wi + InLeftPadW, InRightPadW}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
make_tuple(DynamicPassThrough{N},
DynamicPassThrough{C},
DynamicRightPad{Hi + InLeftPadH, InRightPadH},
DynamicRightPad{Wi + InLeftPadW, InRightPadW}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
#endif #endif
MultiIndex<4> idx; MultiIndex<4> idx;
...@@ -251,18 +249,39 @@ struct DummyDynamicTransform_v2 ...@@ -251,18 +249,39 @@ struct DummyDynamicTransform_v2
#endif #endif
} }
template <typename WeiDesc, typename InDesc, typename OutDesc, typename TransformInDesc> template <typename WeiDesc, typename InDesc, typename OutDesc>
__device__ void Run_3(index_t* const __restrict__ p_wei_global, __device__ void Run(index_t* const __restrict__ p_wei_global,
float* const __restrict__ p_in_global, float* const __restrict__ p_in_global,
float* const __restrict__ p_out_global, float* const __restrict__ p_out_global,
const WeiDesc /* wei_k_c_y_x_global_desc */, const WeiDesc wei_k_c_y_x_global_desc,
const InDesc /* in_n_c_hi_wi_global_desc */, const InDesc in_n_c_hi_wi_global_desc,
const OutDesc /* out_n_k_ho_wo_global_desc */, const OutDesc out_n_k_ho_wo_global_desc,
const TransformInDesc in_gemmk_gemmn_global_desc, const Array<index_t, 2> conv_strides,
const Array<index_t, 2> conv_strides, const Array<index_t, 2> conv_dilations,
const Array<index_t, 2> conv_dilations, const Array<index_t, 2> in_left_pads,
const Array<index_t, 2> in_left_pads, const Array<index_t, 2> in_right_pads) const
const Array<index_t, 2> in_right_pads) const {
Run_1(p_wei_global,
p_in_global,
p_out_global,
wei_k_c_y_x_global_desc,
in_n_c_hi_wi_global_desc,
out_n_k_ho_wo_global_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);
}
};
template <index_t BlockSize>
struct DummyDynamicTransform_v2_2
{
template <typename TransformInDesc>
__device__ void Run(index_t* const __restrict__ p_wei_global,
float* const __restrict__ p_in_global,
float* const __restrict__ p_out_global,
const TransformInDesc in_gemmk_gemmn_global_desc) const
{ {
MultiIndex<2> idx; MultiIndex<2> idx;
...@@ -309,32 +328,6 @@ struct DummyDynamicTransform_v2 ...@@ -309,32 +328,6 @@ struct DummyDynamicTransform_v2
p_out_global[in_gemmk_gemmn_global_desc.CalculateOffset(idx)] = 1; p_out_global[in_gemmk_gemmn_global_desc.CalculateOffset(idx)] = 1;
#endif #endif
} }
template <typename WeiDesc, typename InDesc, typename OutDesc, typename TransformInDesc>
__device__ void Run(index_t* const __restrict__ p_wei_global,
float* const __restrict__ p_in_global,
float* const __restrict__ p_out_global,
const WeiDesc wei_k_c_y_x_global_desc,
const InDesc in_n_c_hi_wi_global_desc,
const OutDesc out_n_k_ho_wo_global_desc,
const TransformInDesc in_gemmk_gemmn_global_desc,
const Array<index_t, 2> conv_strides,
const Array<index_t, 2> conv_dilations,
const Array<index_t, 2> in_left_pads,
const Array<index_t, 2> in_right_pads) const
{
Run_1(p_wei_global,
p_in_global,
p_out_global,
wei_k_c_y_x_global_desc,
in_n_c_hi_wi_global_desc,
out_n_k_ho_wo_global_desc,
in_gemmk_gemmn_global_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);
}
}; };
} // namespace ck } // namespace ck
......
...@@ -50,6 +50,26 @@ void device_dummy_dynamic_transform_v2(InDesc, ...@@ -50,6 +50,26 @@ void device_dummy_dynamic_transform_v2(InDesc,
const auto in_gemmk_gemmn_global_desc = tensor_descs.At(Number<0>{}); const auto in_gemmk_gemmn_global_desc = tensor_descs.At(Number<0>{});
// test on cpu
{
auto in_gemmk_gemmn_coord =
make_dynamic_tensor_coordinate_v2(in_gemmk_gemmn_global_desc, MultiIndex<2>{{0, 0}});
const auto in_gemmk_gemmn_coord_step = make_dynamic_tensor_coordinate_step_v2(
in_gemmk_gemmn_global_desc, MultiIndex<2>{{1, 0}});
for(index_t iter = 0; iter < 10; ++iter)
{
printf("iter %d\n", iter);
print_array("idx: ", in_gemmk_gemmn_coord.GetIndex());
printf("offset: %d\n", in_gemmk_gemmn_coord.GetOffset());
printf("\n");
move_dynamic_tensor_coordinate_v2(
in_gemmk_gemmn_global_desc, in_gemmk_gemmn_coord, in_gemmk_gemmn_coord_step);
}
}
std::size_t data_sz = sizeof(T); std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
...@@ -64,8 +84,6 @@ void device_dummy_dynamic_transform_v2(InDesc, ...@@ -64,8 +84,6 @@ void device_dummy_dynamic_transform_v2(InDesc,
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
using dummy_transform = DummyDynamicTransform_v2<BlockSize>;
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
std::cout << "Start running " << nrepeat << " times..." << std::endl; std::cout << "Start running " << nrepeat << " times..." << std::endl;
...@@ -75,14 +93,14 @@ void device_dummy_dynamic_transform_v2(InDesc, ...@@ -75,14 +93,14 @@ void device_dummy_dynamic_transform_v2(InDesc,
for(index_t j = 0; j < nrepeat; ++j) for(index_t j = 0; j < nrepeat; ++j)
{ {
launch_kernel(run_gridwise_operation<dummy_transform, #if 1
launch_kernel(run_gridwise_operation<DummyDynamicTransform_v2_1<BlockSize>,
index_t* const, index_t* const,
float* const, float* const,
float* const, float* const,
const decltype(wei_kcyx_desc), const decltype(wei_kcyx_desc),
const decltype(in_nchw_desc), const decltype(in_nchw_desc),
const decltype(out_nkhw_desc), const decltype(out_nkhw_desc),
const decltype(in_gemmk_gemmn_global_desc),
const Array<index_t, 2>, const Array<index_t, 2>,
const Array<index_t, 2>, const Array<index_t, 2>,
const Array<index_t, 2>, const Array<index_t, 2>,
...@@ -97,11 +115,33 @@ void device_dummy_dynamic_transform_v2(InDesc, ...@@ -97,11 +115,33 @@ void device_dummy_dynamic_transform_v2(InDesc,
wei_kcyx_desc, wei_kcyx_desc,
in_nchw_desc, in_nchw_desc,
out_nkhw_desc, out_nkhw_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);
#else
launch_kernel(run_gridwise_operation<DummyDynamicTransform_v2_2<BlockSize>,
index_t* const,
float* const,
float* const,
const decltype(in_gemmk_gemmn_global_desc),
const Array<index_t, 2>,
const Array<index_t, 2>,
const Array<index_t, 2>,
const Array<index_t, 2>>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<index_t*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<float*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<float*>(out_nkhw_device_buf.GetDeviceBuffer()),
in_gemmk_gemmn_global_desc, in_gemmk_gemmn_global_desc,
conv_strides, conv_strides,
conv_dilations, conv_dilations,
in_left_pads, in_left_pads,
in_right_pads); in_right_pads);
#endif
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment