Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2488d0bf
"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "54e1251e05261ffb369fbc898d7478d24dd10e69"
Commit
2488d0bf
authored
Jun 13, 2022
by
Chao Liu
Browse files
refactor
parent
97ec23bf
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
56 additions
and
33 deletions
+56
-33
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+33
-20
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp
...ration/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp
+4
-7
include/ck/utility/sequence.hpp
include/ck/utility/sequence.hpp
+7
-3
include/ck/utility/tuple_helper.hpp
include/ck/utility/tuple_helper.hpp
+12
-3
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
2488d0bf
...
@@ -541,21 +541,33 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
...
@@ -541,21 +541,33 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
n_thread_data_on_block_idx
[
I2
]),
n_thread_data_on_block_idx
[
I2
]),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
// shuffle: blockwise copy C from LDS to global
// tuple of reference to C/Ds tensor descriptors
// FIXME: arbitrary # of D tensors
const
auto
c_ds_desc_refs
=
concat_tuple_of_reference
(
const
auto
c_ds_descs
=
tie
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
tie
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I0
],
generate_tie
(
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I1
]);
[
&
](
auto
i
)
->
const
auto
&
// return type should be reference
{
return
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
i
];
},
Number
<
NumDTensor
>
{}));
// tuple of starting index of C/Ds blockwise copy
const
auto
idx_c_ds_block_begin
=
container_concat
(
make_tuple
(
make_multi_index
(
0
,
0
,
0
,
0
)),
generate_tuple
(
[
&
](
auto
)
{
return
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
);
},
Number
<
NumDTensor
>
{}));
// blockwise copy C/D/E between LDS and global
auto
cde_block_copy_lds_and_global
=
ThreadGroupTensorSliceTransfer_v7
<
auto
cde_block_copy_lds_and_global
=
ThreadGroupTensorSliceTransfer_v7
<
ThisThreadBlock
,
// ThreadGroup
ThisThreadBlock
,
Tuple
<
FloatCShuffle
,
Tuple
<
FloatCShuffle
,
remove_cvref_t
<
tuple_element_t
<
0
,
DsDataType
>>
,
remove_cvref_t
<
tuple_element_t
<
0
,
DsDataType
>>
,
remove_cvref_t
<
tuple_element_t
<
1
,
DsDataType
>>>
,
remove_cvref_t
<
tuple_element_t
<
1
,
DsDataType
>>>
,
Tuple
<
FloatE
>
,
// typename DstData,
Tuple
<
FloatE
>
,
decltype
(
c_ds_descs
),
decltype
(
c_ds_desc
_ref
s
),
decltype
(
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
)),
decltype
(
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
)),
CDEElementwiseOperation
,
// ElementwiseOperation,
CDEElementwiseOperation
,
Sequence
<
static_cast
<
index_t
>
(
EGlobalMemoryDataOperation
)
>
,
// FIXME: make Sequence
Sequence
<
static_cast
<
index_t
>
(
EGlobalMemoryDataOperation
)
>
,
// FIXME: make Sequence
// support arbitrary type
// support arbitrary type
Sequence
<
1
,
Sequence
<
1
,
...
@@ -566,13 +578,14 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
...
@@ -566,13 +578,14 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
3
,
// index_t VectorDim,
3
,
// index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
// index_t ScalarPerVector,
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
Sequence
<
true
,
false
,
false
>
,
// ThreadTransferSrcResetCoordinateAfterRunFlags
sequence_merge_t
<
Sequence
<
false
>>
// ThreadTransferDstResetCoordinateAfterRunFlags
Sequence
<
true
>
,
{
c_ds_descs
,
uniform_sequence_gen_t
<
NumDTensor
,
make_tuple
(
make_multi_index
(
0
,
0
,
0
,
0
),
false
>>
,
// ThreadTransferSrcResetCoordinateAfterRunFlags
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
),
Sequence
<
false
>>
// ThreadTransferDstResetCoordinateAfterRunFlags
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
)),
{
c_ds_desc_refs
,
idx_c_ds_block_begin
,
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
make_tuple
(
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
)),
make_tuple
(
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
)),
cde_element_op
};
cde_element_op
};
...
@@ -619,7 +632,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
...
@@ -619,7 +632,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
// each block copy its data from LDS to global
// each block copy its data from LDS to global
cde_block_copy_lds_and_global
.
Run
(
cde_block_copy_lds_and_global
.
Run
(
c_ds_descs
,
c_ds_desc
_ref
s
,
tie
(
c_shuffle_block_buf
,
ds_grid_buf
[
I0
],
ds_grid_buf
[
I1
]),
tie
(
c_shuffle_block_buf
,
ds_grid_buf
[
I0
],
ds_grid_buf
[
I1
]),
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
tie
(
e_grid_buf
));
tie
(
e_grid_buf
));
...
@@ -630,9 +643,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
...
@@ -630,9 +643,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
sfc_cde_block
.
GetForwardStep
(
access_id
);
sfc_cde_block
.
GetForwardStep
(
access_id
);
// move on Ds
// move on Ds
static_for
<
0
,
DsDataType
::
Size
()
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
cde_block_copy_lds_and_global
.
MoveSrcSliceWindow
(
cde_block_copy_lds_and_global
.
MoveSrcSliceWindow
(
c_ds_descs
,
i
+
I1
,
cde_lds_and_global_step
);
c_ds_desc
_ref
s
,
i
+
I1
,
cde_lds_and_global_step
);
});
});
// move on E
// move on E
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp
View file @
2488d0bf
...
@@ -7,6 +7,10 @@
...
@@ -7,6 +7,10 @@
namespace
ck
{
namespace
ck
{
// Assume:
// 1. src_descs and dst_descs are not known at compile-time
// 2. SrcBuffers and DstBuffers are DynamicBuffer
// 3. src_slice_origins and dst_slice_origins are not known at compile-time,
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions:
// and sometimes useless instructions:
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
...
@@ -14,11 +18,6 @@ namespace ck {
...
@@ -14,11 +18,6 @@ namespace ck {
// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
// tensor coordinate instead
// tensor coordinate instead
// 3. Don't use a pointer to VGPR buffer, use vector instead
// 3. Don't use a pointer to VGPR buffer, use vector instead
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
template
<
typename
SrcDatas
,
template
<
typename
SrcDatas
,
typename
DstDatas
,
typename
DstDatas
,
typename
SrcDescs
,
typename
SrcDescs
,
...
@@ -34,8 +33,6 @@ template <typename SrcDatas,
...
@@ -34,8 +33,6 @@ template <typename SrcDatas,
struct
ThreadwiseTensorSliceTransfer_v7
struct
ThreadwiseTensorSliceTransfer_v7
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
index_t
nDim
=
SliceLengths
::
Size
();
static
constexpr
index_t
nDim
=
SliceLengths
::
Size
();
...
...
include/ck/utility/sequence.hpp
View file @
2488d0bf
#ifndef CK_SEQUENCE_HPP
#pragma once
#define CK_SEQUENCE_HPP
#include "integral_constant.hpp"
#include "integral_constant.hpp"
#include "type.hpp"
#include "type.hpp"
...
@@ -882,5 +881,10 @@ __host__ __device__ constexpr bool sequence_all_of(Seq, F f)
...
@@ -882,5 +881,10 @@ __host__ __device__ constexpr bool sequence_all_of(Seq, F f)
return
flag
;
return
flag
;
}
}
// Shorthand for the nested `type` of sequence_merge<SeqA, SeqB>, so callers
// can omit the `typename ...::type` boilerplate.
template <typename SeqA, typename SeqB>
using sequence_merge_t = typename sequence_merge<SeqA, SeqB>::type;
// Shorthand for the nested `type` of uniform_sequence_gen<Count, Value>.
// NOTE(review): presumably a Sequence holding `Count` copies of `Value` --
// confirm against the definition of uniform_sequence_gen.
template <index_t Count, index_t Value>
using uniform_sequence_gen_t = typename uniform_sequence_gen<Count, Value>::type;
}
// namespace ck
}
// namespace ck
#endif
include/ck/utility/tuple_helper.hpp
View file @
2488d0bf
#ifndef CK_TUPLE_HELPER_HPP
#pragma once
#define CK_TUPLE_HELPER_HPP
#include "functional4.hpp"
#include "functional4.hpp"
#include "tuple.hpp"
#include "tuple.hpp"
...
@@ -20,6 +19,17 @@ __host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
...
@@ -20,6 +19,17 @@ __host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
typename
arithmetic_sequence_gen
<
0
,
N
,
1
>::
type
{});
typename
arithmetic_sequence_gen
<
0
,
N
,
1
>::
type
{});
}
}
// Build one tuple of references out of two tuples of references.
// tx and ty are tuples of lvalue references; the result is again a tuple of
// references (its element types are the reference types of the forwarded
// arguments), not a tuple of values or rvalues.
// NOTE(review): relies on unpack2 (defined elsewhere) invoking the callable
// with the elements of tx followed by the elements of ty -- confirm.
template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& tx,
                                                             const Tuple<Y&...>& ty)
{
    return unpack2(
        [&](auto&&... refs) {
            // decltype(refs) is always a reference type here, so the resulting
            // Tuple stores references rather than copies.
            using joined_t = Tuple<decltype(refs)...>;
            return joined_t{std::forward<decltype(refs)>(refs)...};
        },
        tx,
        ty);
}
namespace
detail
{
namespace
detail
{
template
<
typename
F
,
typename
X
,
index_t
...
Is
>
template
<
typename
F
,
typename
X
,
index_t
...
Is
>
...
@@ -66,4 +76,3 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
...
@@ -66,4 +76,3 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
}
}
}
// namespace ck
}
// namespace ck
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment