Commit 638d3f02 authored by Adam Osewski

Fix loading Ds tensors.

parent e7aad72f
@@ -1303,14 +1303,14 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
});
// TODO: on MI300 we could use NonTemporal load, MI200 streaming mode?
-auto ds_grid_buf = generate_tuple(
+const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i], ds_grid_desc_m0m1_n0n1n2[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
-constexpr auto ds_thread_buf = generate_tuple(
+auto ds_thread_buf = generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return StaticBuffer<AddressSpaceEnum::Vgpr, DDataType, ScalarPerVector, true>{};
@@ -1353,7 +1353,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
using SliceLengths = Sequence<I1, I1, I1, I1, ScalarPerVector>;
return ThreadwiseTensorSliceTransfer_v2<DDataType,
DDataType,
-decltype(ds_grid_desc_m0m1_n0n1n2(i)),
+decltype(ds_grid_desc_m0m1_n0n1n2[i]),
decltype(d_vgpr_buf_desc),
SliceLengths,
Sequence<0, 1, 2, 3, 4>,
@@ -1361,7 +1361,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
ScalarPerVector,
1,
false>{
-ds_grid_desc_m_n(i),
+ds_grid_desc_m0m1_n0n1n2[i],
make_multi_index(
block_work_idx[I0],
thread_m_cluster_id * workspace_thread_desc_m0m1_n0n1n2.GetLength(I1),
@@ -1410,8 +1410,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
static_for<0, NIter, 1>{}([&](auto n_idx) {
// load multiple Ds:
static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
-ds_grid_load(d_idx).Run(ds_grid_desc_m0m1_n0n1n2(d_idx),
-ds_grid_buf(d_idx),
+ds_grid_load(d_idx).Run(ds_grid_desc_m0m1_n0n1n2[d_idx],
+ds_grid_buf[d_idx],
d_vgpr_buf_desc,
make_tuple(I0, I0, I0, I0, I0),
ds_thread_buf(d_idx));
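In brief, the fix does two things. The tuple of global Ds buffers becomes `const auto` while the per-thread staging buffer `ds_thread_buf` drops `constexpr` (it is written at runtime), and read-only accesses switch from `operator()` to `operator[]` — consistent with the convention the diff suggests for ck's `Tuple`, where `operator[]` is the const accessor and `operator()` the mutable one. The constructor argument also changes from `ds_grid_desc_m_n(i)` to `ds_grid_desc_m0m1_n0n1n2[i]`, so the descriptor passed at runtime now matches the `decltype(ds_grid_desc_m0m1_n0n1n2[i])` named in the transfer's template arguments. A minimal sketch of that access convention, assuming a simplified stand-in type (`MiniTuple` below is hypothetical, not composable_kernel code):

```cpp
#include <array>
#include <cstdio>

// Hypothetical stand-in for a tuple-like container with the convention
// the diff relies on: operator[] = const access, operator() = mutable access.
template <typename T, int N>
struct MiniTuple
{
    std::array<T, N> data_;

    // Read-only access: usable on const objects, like ds_grid_buf[i] above.
    constexpr const T& operator[](int i) const { return data_[i]; }

    // Mutable access: requires a non-const object, like ds_thread_buf(d_idx) above.
    constexpr T& operator()(int i) { return data_[i]; }
};

int main()
{
    // Analogous to `const auto ds_grid_buf`: read-only once constructed.
    const MiniTuple<int, 3> grid_bufs{{10, 20, 30}};

    // Analogous to `auto ds_thread_buf`: written at runtime, so not constexpr.
    MiniTuple<int, 3> thread_bufs{{0, 0, 0}};

    // grid_bufs(0) would fail to compile: operator() is the mutable overload.
    std::printf("%d\n", grid_bufs[0]); // const access via operator[]

    thread_bufs(1) = 42; // writes go through the mutable overload
    std::printf("%d\n", thread_bufs[1]);
    return 0;
}
```

With `grid_bufs` declared `const`, calling `grid_bufs(0)` is rejected at compile time, which mirrors why the now-`const` `ds_grid_buf` must be read through `[]`.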