Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ad11d2a4
Commit
ad11d2a4
authored
Jun 14, 2022
by
Chao Liu
Browse files
fix
parent
2488d0bf
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
20 additions
and
6 deletions
+20
-6
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
+3
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+10
-4
include/ck/utility/sequence.hpp
include/ck/utility/sequence.hpp
+7
-1
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
View file @
ad11d2a4
...
...
@@ -503,8 +503,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
static_assert
(
NumDTensor
==
0
,
"wrong!"
);
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
decltype
(
DsDataType
{}.
At
(
i
))
>
;
using
DDataType
=
tuple_element_t
<
i
.
value
,
DsDataType
>
;
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds_grid
[
i
]);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
ad11d2a4
...
...
@@ -549,6 +549,14 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
{
return
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
i
];
},
Number
<
NumDTensor
>
{}));
// tuple of reference to C/Ds tensor descriptors
const
auto
c_ds_buf_refs
=
concat_tuple_of_reference
(
tie
(
c_shuffle_block_buf
),
generate_tie
(
[
&
](
auto
i
)
->
const
auto
&
// return type should be reference
{
return
ds_grid_buf
[
i
];
},
Number
<
NumDTensor
>
{}));
// tuple of starting index of C/Ds blockwise copy
const
auto
idx_c_ds_block_begin
=
container_concat
(
make_tuple
(
make_multi_index
(
0
,
0
,
0
,
0
)),
...
...
@@ -561,9 +569,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
// blockwise copy C/D/E between LDS and global
auto
cde_block_copy_lds_and_global
=
ThreadGroupTensorSliceTransfer_v7
<
ThisThreadBlock
,
Tuple
<
FloatCShuffle
,
remove_cvref_t
<
tuple_element_t
<
0
,
DsDataType
>>
,
remove_cvref_t
<
tuple_element_t
<
1
,
DsDataType
>>>
,
decltype
(
container_concat
(
make_tuple
(
FloatCShuffle
{}),
DsDataType
{})),
Tuple
<
FloatE
>
,
decltype
(
c_ds_desc_refs
),
decltype
(
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
)),
...
...
@@ -633,7 +639,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
// each block copy its data from LDS to global
cde_block_copy_lds_and_global
.
Run
(
c_ds_desc_refs
,
tie
(
c_shuffle_block_buf
,
ds_grid_buf
[
I0
],
ds_grid_buf
[
I1
])
,
c_ds_buf_refs
,
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
tie
(
e_grid_buf
));
...
...
include/ck/utility/sequence.hpp
View file @
ad11d2a4
...
...
@@ -240,7 +240,13 @@ struct arithmetic_sequence_gen
}
};
using
type
=
typename
sequence_gen
<
(
IEnd
-
IBegin
)
/
Increment
,
F
>::
type
;
using
type0
=
typename
sequence_gen
<
(
IEnd
-
IBegin
)
/
Increment
,
F
>::
type
;
using
type1
=
Sequence
<>
;
static
constexpr
bool
kHasContent
=
(
Increment
>
0
&&
IBegin
<
IEnd
)
||
(
Increment
<
0
&&
IBegin
>
IEnd
);
using
type
=
typename
conditional
<
kHasContent
,
type0
,
type1
>::
type
;
};
// uniform sequence
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment