Commit e2bfa9bb authored by Po-Yen, Chen
Browse files

Fix wrong output descriptor for 2nd blockwise copy

parent 29053edd
......@@ -133,9 +133,7 @@ struct GridwiseCopy
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));
static constexpr auto I2 = Number<2>{};
using PassThroughOp = tensor_operation::element_wise::PassThrough;
......@@ -209,19 +207,6 @@ struct GridwiseCopy
// const index_t loop_step = blockPerGrid * blockSize * MPerThread;
const auto loop_step_index = make_multi_index(1, 0, 0);
#if 0
auto in_global_load =
ThreadwiseTensorSliceTransfer_v2<InDataType,
InDataType,
decltype(in_grid_1d_desc),
decltype(thread_buffer_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
InScalarPerVector, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{in_grid_1d_desc, thread_global_offset};
#else
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
......@@ -292,21 +277,29 @@ struct GridwiseCopy
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
#endif
auto in_grid_1d_desc_tranformed = transform_tensor_descriptor(
in_grid_1d_desc,
make_tuple(make_pass_through_transform(in_grid_1d_desc.GetLength(I0)),
make_pass_through_transform(in_grid_1d_desc.GetLength(I1)),
make_pass_through_transform(in_grid_1d_desc.GetLength(I2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}));
auto out_global_store = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
ElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, WPerBlock, HPerBlock>, // SliceLengths
Sequence<1, HPerBlock, WPerBlock>, // SliceLengths
ABlockTransferThreadClusterLengths,
Sequence<0, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
InDataType,
OutDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(out_grid_1d_desc),
decltype(in_grid_1d_desc_tranformed),
Sequence<0, 1, 2>, // ABlockTransferSrcAccessOrder
Sequence<0, 2, 1>, // ABlockTransferDstAccessOrder
Sequence<0, 1, 2>, // ABlockTransferDstAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferDstVectorDim
1, // ABlockTransferSrcScalarPerVector
......@@ -317,8 +310,8 @@ struct GridwiseCopy
true>(a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{},
out_grid_1d_desc,
make_multi_index(0, w_block_data_idx_on_grid, h_block_data_idx_on_grid),
in_grid_1d_desc_tranformed,
make_multi_index(0, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
elementwise_op);
index_t num_iter = in_grid_1d_desc.GetLength(I0);
......@@ -329,10 +322,13 @@ struct GridwiseCopy
in_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index);
out_global_store.Run(
a_block_desc_ak0_m_ak1, a_block_buf, out_grid_1d_desc, out_global_buf, I0);
out_global_store.Run(a_block_desc_ak0_m_ak1,
a_block_buf,
in_grid_1d_desc_tranformed,
out_global_buf,
I0);
out_global_store.MoveDstSliceWindow(out_grid_1d_desc, loop_step_index);
out_global_store.MoveDstSliceWindow(in_grid_1d_desc_tranformed, loop_step_index);
} while(--num_iter);
}
};
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment