Commit 638d3f02 authored by Adam Osewski
Browse files

Fix loading Ds tensors.

parent e7aad72f
...@@ -1303,14 +1303,14 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2 ...@@ -1303,14 +1303,14 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
}); });
// TODO: on MI300 we could use NonTemporal load, MI200 streaming mode? // TODO: on MI300 we could use NonTemporal load, MI200 streaming mode?
auto ds_grid_buf = generate_tuple( const auto ds_grid_buf = generate_tuple(
[&](auto i) { [&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>( return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i], ds_grid_desc_m0m1_n0n1n2[i].GetElementSpaceSize()); p_ds_grid[i], ds_grid_desc_m0m1_n0n1n2[i].GetElementSpaceSize());
}, },
Number<NumDTensor>{}); Number<NumDTensor>{});
constexpr auto ds_thread_buf = generate_tuple( auto ds_thread_buf = generate_tuple(
[&](auto i) { [&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>; using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return StaticBuffer<AddressSpaceEnum::Vgpr, DDataType, ScalarPerVector, true>{}; return StaticBuffer<AddressSpaceEnum::Vgpr, DDataType, ScalarPerVector, true>{};
...@@ -1353,7 +1353,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2 ...@@ -1353,7 +1353,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
using SliceLengths = Sequence<I1, I1, I1, I1, ScalarPerVector>; using SliceLengths = Sequence<I1, I1, I1, I1, ScalarPerVector>;
return ThreadwiseTensorSliceTransfer_v2<DDataType, return ThreadwiseTensorSliceTransfer_v2<DDataType,
DDataType, DDataType,
decltype(ds_grid_desc_m0m1_n0n1n2(i)), decltype(ds_grid_desc_m0m1_n0n1n2[i]),
decltype(d_vgpr_buf_desc), decltype(d_vgpr_buf_desc),
SliceLengths, SliceLengths,
Sequence<0, 1, 2, 3, 4>, Sequence<0, 1, 2, 3, 4>,
...@@ -1361,7 +1361,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2 ...@@ -1361,7 +1361,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
ScalarPerVector, ScalarPerVector,
1, 1,
false>{ false>{
ds_grid_desc_m_n(i), ds_grid_desc_m0m1_n0n1n2[i],
make_multi_index( make_multi_index(
block_work_idx[I0], block_work_idx[I0],
thread_m_cluster_id * workspace_thread_desc_m0m1_n0n1n2.GetLength(I1), thread_m_cluster_id * workspace_thread_desc_m0m1_n0n1n2.GetLength(I1),
...@@ -1410,8 +1410,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2 ...@@ -1410,8 +1410,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
static_for<0, NIter, 1>{}([&](auto n_idx) { static_for<0, NIter, 1>{}([&](auto n_idx) {
// load multiple Ds: // load multiple Ds:
static_for<0, NumDTensor, 1>{}([&](auto d_idx) { static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
ds_grid_load(d_idx).Run(ds_grid_desc_m0m1_n0n1n2(d_idx), ds_grid_load(d_idx).Run(ds_grid_desc_m0m1_n0n1n2[d_idx],
ds_grid_buf(d_idx), ds_grid_buf[d_idx],
d_vgpr_buf_desc, d_vgpr_buf_desc,
make_tuple(I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0),
ds_thread_buf(d_idx)); ds_thread_buf(d_idx));
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment