"...composable_kernel.git" did not exist on "70eebf22533fa420c1351c28d0e1390b1b4bd33a"
Commit 55e420ec authored by rocking's avatar rocking
Browse files

Arrange the code

parent e67b958b
...@@ -131,6 +131,12 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp ...@@ -131,6 +131,12 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const index_t MRaw = N * DTildeSlice * HTildeSlice * WTildeSlice * C;
const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
const index_t KRaw = ZDotSlice * YDotSlice * XDotSlice;
const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
// Out[ReduceM, ReduceK] // Out[ReduceM, ReduceK]
const auto out_n_dop_hop_wop_c_grid_desc = transform_tensor_descriptor( const auto out_n_dop_hop_wop_c_grid_desc = transform_tensor_descriptor(
out_n_do_ho_wo_c_grid_desc, out_n_do_ho_wo_c_grid_desc,
...@@ -199,12 +205,6 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp ...@@ -199,12 +205,6 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}), make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
const index_t MRaw = N * DTildeSlice * HTildeSlice * WTildeSlice * C;
const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
const index_t KRaw = ZDotSlice * YDotSlice * XDotSlice;
const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
const auto out_grid_desc_reducem_reducek = transform_tensor_descriptor( const auto out_grid_desc_reducem_reducek = transform_tensor_descriptor(
out_grid_desc_reducemraw_reducekraw, out_grid_desc_reducemraw_reducekraw,
make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)), make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment