Commit ad77ce8e authored by letaoqin's avatar letaoqin
Browse files

fix for passthrough element op

parent f820c621
...@@ -544,115 +544,116 @@ struct GridwiseGemmDlMultipleD_km_kn_mn ...@@ -544,115 +544,116 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id()); get_thread_local_1d_id());
const auto ds_grid_buf = generate_tuple( if constexpr(!is_same_v<CDEElementwiseOperation,
[&](auto i) { ck::tensor_operation::element_wise::PassThrough>)
return make_dynamic_buffer<AddressSpaceEnum::Global>( {
p_ds_grid[i], ds_grid_desc_m0_m10_m11_n0_n10_n11[i].GetElementSpaceSize()); const auto ds_grid_buf = generate_tuple(
}, [&](auto i) {
Number<NumDTensor>{}); return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
auto ds_thread_buf = generate_tuple( ds_grid_desc_m0_m10_m11_n0_n10_n11[i].GetElementSpaceSize());
[&](auto i) { },
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>; Number<NumDTensor>{});
return StaticBuffer<AddressSpaceEnum::Vgpr, auto ds_thread_buf = generate_tuple(
DDataType, [&](auto i) {
c_m10_m11_n10_n11_thread_tensor_lengths[I3], using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
true>{};
}, return StaticBuffer<AddressSpaceEnum::Vgpr,
Number<NumDTensor>{}); DDataType,
c_m10_m11_n10_n11_thread_tensor_lengths[I3],
auto ds_threadwise_copy = generate_tuple( true>{};
[&](auto i) { },
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>; Number<NumDTensor>{});
return ThreadwiseTensorSliceTransfer_v2< auto ds_threadwise_copy = generate_tuple(
DDataType, [&](auto i) {
DDataType, using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
decltype(ds_grid_desc_m0_m10_m11_n0_n10_n11[i]),
decltype(c_thread_desc_m0_m10_m11_n0_n10_n11), return ThreadwiseTensorSliceTransfer_v2<
Sequence<I1, DDataType,
I1, DDataType,
I1, decltype(ds_grid_desc_m0_m10_m11_n0_n10_n11[i]),
I1, decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
I1, Sequence<I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}>, I1,
CThreadTransferSrcDstAccessOrder, I1,
CThreadTransferSrcDstVectorDim, I1,
CThreadTransferDstScalarPerVector, I1,
1, Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}>,
false>( CThreadTransferSrcDstAccessOrder,
ds_grid_desc_m0_m10_m11_n0_n10_n11[i], CThreadTransferSrcDstVectorDim,
make_multi_index( CThreadTransferDstScalarPerVector,
im0, 1,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], false>(
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
in0, make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])); // register number c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
}, in0,
Number<NumDTensor>{}); c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]));
static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I0], 1>{}([&](auto m10) { },
static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I1], 1>{}([&](auto m11) { Number<NumDTensor>{});
static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I2], 1>{}([&](auto n10) {
ignore = m10; static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I0], 1>{}([&](auto m10) {
ignore = m11; static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I1], 1>{}([&](auto m11) {
ignore = n10; static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I2], 1>{}(
ignore = ds_thread_buf; [&](auto n10) {
ignore = ds_threadwise_copy; // load d matrix data
ignore = ds_grid_buf; static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_threadwise_copy(i).Run(ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
static_for<0, NumDTensor, 1>{}([&](auto i) { ds_grid_buf[i],
ds_threadwise_copy(i).Run(ds_grid_desc_m0_m10_m11_n0_n10_n11[i], c_thread_desc_m0_m10_m11_n0_n10_n11,
ds_grid_buf[i], make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_desc_m0_m10_m11_n0_n10_n11, ds_thread_buf(i));
make_tuple(I0, I0, I0, I0, I0, I0), });
ds_thread_buf(i)); // cal element op
}); static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I3], 1>{}(
[&](auto i) {
static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I3], 1>{}( // get reference to src data
[&](auto i) { const auto src_data_refs = generate_tie(
// get reference to src data // return type should be lvalue
const auto src_data_refs = generate_tie( [&](auto iSrc) -> const auto& {
// return type should be lvalue return ds_thread_buf[iSrc][i];
[&](auto iSrc) -> const auto& { },
return ds_thread_buf[iSrc][i]; Number<NumDTensor>{});
},
Number<NumDTensor>{}); // get reference to dst data
constexpr index_t c_offset =
// get reference to dst data c_thread_desc_m0_m10_m11_n0_n10_n11.CalculateOffset(
constexpr index_t c_offset = make_tuple(0, m10, m11, 0, n10, i));
c_thread_desc_m0_m10_m11_n0_n10_n11.CalculateOffset( auto dst_data_refs = generate_tie(
make_tuple(0, m10, m11, 0, n10, i)); // return type should be lvalue
auto dst_data_refs = generate_tie( [&](auto) -> auto& {
// return type should be lvalue return c_thread_buf(Number<c_offset>{});
[&](auto) -> auto& { return c_thread_buf(Number<c_offset>{}); }, },
Number<2>{}); Number<2>{});
unpack2(cde_element_op, dst_data_refs, src_data_refs); unpack2(cde_element_op, dst_data_refs, src_data_refs);
});
static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_threadwise_copy(i).MoveSrcSliceWindow(
ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
make_multi_index(0, 0, 0, 0, 1, 0));
});
}); });
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_threadwise_copy(i).MoveSrcSliceWindow( ds_threadwise_copy(i).MoveSrcSliceWindow(
ds_grid_desc_m0_m10_m11_n0_n10_n11[i], ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
make_multi_index(0, 0, 0, 0, 1, 0)); make_multi_index(
0, 0, 1, 0, -c_m10_m11_n10_n11_thread_tensor_lengths[I2], 0));
}); });
}); });
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_threadwise_copy(i).MoveSrcSliceWindow( ds_threadwise_copy(i).MoveSrcSliceWindow(
ds_grid_desc_m0_m10_m11_n0_n10_n11[i], ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
make_multi_index( make_multi_index(
0, 0, 1, 0, -c_m10_m11_n10_n11_thread_tensor_lengths[I2], 0)); 0, 1, -c_m10_m11_n10_n11_thread_tensor_lengths[I1], 0, 0, 0));
}); });
}); });
static_for<0, NumDTensor, 1>{}([&](auto i) { }
ds_threadwise_copy(i).MoveSrcSliceWindow(
ds_grid_desc_m0_m10_m11_n0_n10_n11[i],
make_multi_index(
0, 1, -c_m10_m11_n10_n11_thread_tensor_lengths[I1], 0, 0, 0));
});
});
ThreadwiseTensorSliceTransfer_v1r3< ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc, FloatAcc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment