Commit bc6513a2 authored by Chao Liu
Browse files

ThreadwiseTensorSliceTransfer_v3r2: support pointwise op on both src and dst

parent 583aab02
...@@ -507,18 +507,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6 ...@@ -507,18 +507,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r6
do do
{ {
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
a_block_slice_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1, b_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n_k1,
b_block_slice_copy_step);
a_blockwise_copy.RunRead( a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
a_grid_desc_k0_m_k1, a_grid_buf);
block_sync_lds(); block_sync_lds();
b_blockwise_copy.RunRead( b_blockwise_copy.RunRead(b_grid_desc_k0_n_k1, b_grid_buf);
b_grid_desc_k0_n_k1, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
......
...@@ -447,13 +447,24 @@ struct ThreadwiseTensorSliceTransfer_v3r2 ...@@ -447,13 +447,24 @@ struct ThreadwiseTensorSliceTransfer_v3r2
const bool is_dst_valid = const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
using dst_vector_t = typename vector_type_maker_t<DstData, DstScalarPerVector>::type; using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
// copy data from dst_thread_scratch_ to dst_buf // copy data from dst_thread_scratch_ into dst_vector_container
auto dst_vector_container = dst_vector_type{
dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)};
// apply DstElementwiseOperation on dst_vector_container
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
dst_vector_container.template AsType<DstData>()(i) =
dst_element_op_(dst_vector_container.template AsType<DstData>()[i]);
});
// copy data from dst_vector_container to dst_buf
dst_buf.template Set<dst_vector_t>( dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(), dst_coord_.GetOffset(),
is_dst_valid, is_dst_valid,
dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)); dst_vector_container.template AsType<dst_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
{ {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment