Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
638d3f02
Commit
638d3f02
authored
Jun 06, 2024
by
Adam Osewski
Browse files
Fix loading Ds tensors.
parent
e7aad72f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
6 deletions
+6
-6
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp
.../grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp
+6
-6
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp
View file @
638d3f02
...
...
@@ -1303,14 +1303,14 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
});
// TODO: on MI300 we could use NonTemporal load, MI200 streaming mode?
auto
ds_grid_buf
=
generate_tuple
(
const
auto
ds_grid_buf
=
generate_tuple
(
[
&
](
auto
i
)
{
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ds_grid
[
i
],
ds_grid_desc_m0m1_n0n1n2
[
i
].
GetElementSpaceSize
());
},
Number
<
NumDTensor
>
{});
constexpr
auto
ds_thread_buf
=
generate_tuple
(
auto
ds_thread_buf
=
generate_tuple
(
[
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
return
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
DDataType
,
ScalarPerVector
,
true
>
{};
...
...
@@ -1353,7 +1353,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
using
SliceLengths
=
Sequence
<
I1
,
I1
,
I1
,
I1
,
ScalarPerVector
>
;
return
ThreadwiseTensorSliceTransfer_v2
<
DDataType
,
DDataType
,
decltype
(
ds_grid_desc_m0m1_n0n1n2
(
i
)
),
decltype
(
ds_grid_desc_m0m1_n0n1n2
[
i
]
),
decltype
(
d_vgpr_buf_desc
),
SliceLengths
,
Sequence
<
0
,
1
,
2
,
3
,
4
>
,
...
...
@@ -1361,7 +1361,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
ScalarPerVector
,
1
,
false
>
{
ds_grid_desc_m
_n
(
i
)
,
ds_grid_desc_m
0m1_n0n1n2
[
i
]
,
make_multi_index
(
block_work_idx
[
I0
],
thread_m_cluster_id
*
workspace_thread_desc_m0m1_n0n1n2
.
GetLength
(
I1
),
...
...
@@ -1410,8 +1410,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
static_for
<
0
,
NIter
,
1
>
{}([
&
](
auto
n_idx
)
{
// load multiple Ds:
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
d_idx
)
{
ds_grid_load
(
d_idx
).
Run
(
ds_grid_desc_m0m1_n0n1n2
(
d_idx
)
,
ds_grid_buf
(
d_idx
)
,
ds_grid_load
(
d_idx
).
Run
(
ds_grid_desc_m0m1_n0n1n2
[
d_idx
]
,
ds_grid_buf
[
d_idx
]
,
d_vgpr_buf_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
),
ds_thread_buf
(
d_idx
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment