gaoqiong / composable_kernel

Commit ff4f8ba8, authored Jun 13, 2022 by Chao Liu

    refactoring; add readme

parent 25e35b59

Showing 8 changed files with 316 additions and 101 deletions
example/03_gemm_bias_add_fastgelu/README.md                                      +22    -0
example/03_gemm_bias_add_fastgelu/gemm_bias_add_fastgelu_xdl_fp16.cpp             +1    -7
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp +179    -0
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp    +0   -40
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp   +20    -6
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp   +47   -32
include/ck/utility/amd_buffer_addressing.hpp                                      +2    -0
include/ck/utility/tuple.hpp                                                     +45   -16
example/03_gemm_bias_add_fastgelu/README.md (new file, mode 100644)

# Instructions for ```example_gemm_bias_add_fastgelu_xdl_fp16```

## Run ```example_gemm_bias_add_fastgelu_xdl_fp16```
```bash
# arg1: verification (0=no, 1=yes)
# arg2: initialization (0=no init, 1=integer value, 2=decimal value)
# arg3: time kernel (0=no, 1=yes)
./bin/example_gemm_bias_add_fastgelu_xdl_fp16 1 1 1
```
Result (MI100 @ 1087 MHz, 133.5 TFlops peak FP16)
```
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1}
d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
Warm up 1 time
Start running 10 times...
Perf: 1.26914 ms, 101.525 TFlops, 100.804 GB/s, DeviceGemmMultipleD_Xdl_CShuffle<256, 256, 128, 32, 8, 8>
```
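
For reference, the TFlops figure is the standard 2·M·N·K GEMM FLOP count over the measured time: 2 × 3840 × 4096 × 4096 ≈ 128.85 GFlop, and 128.85 GFlop / 1.26914 ms ≈ 101.5 TFlops, i.e. about 76% of the quoted 133.5 TFlops FP16 peak.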
example/03_gemm_bias_add_fastgelu/gemm_bias_add_fastgelu_xdl_fp16.cpp

```diff
@@ -113,7 +113,7 @@ int main(int argc, char* argv[])
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
-        printf("arg3: time kernel (0=n0, 1=yes)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
         printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n");
         exit(0);
     }
@@ -161,12 +161,6 @@ int main(int argc, char* argv[])
         d1_m_n.GenerateTensorValue(GeneratorTensor_3<EDataType>{0.0, 1.0});
     }

-    std::cout << "a: " << a_m_k.mDesc.GetElementSpace() << std::endl;
-    std::cout << "b: " << b_k_n.mDesc.GetElementSpace() << std::endl;
-    std::cout << "d0: " << d0_m_n.mDesc.GetElementSpace() << std::endl;
-    std::cout << "d1: " << d1_m_n.mDesc.GetElementSpace() << std::endl;
-    std::cout << "e: " << e_m_n_device_result.mDesc.GetElementSpace() << std::endl;
-
     DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
     DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
     DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
```
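For context, the kernel in this example computes E = FastGelu(C + D0 + D1), where C = A·B, D0 is a bias broadcast along M (hence the {0, 1} strides in the README log), and D1 is a full M×N addend. A host-side reference sketch of that epilogue, assuming the usual tanh-based FastGelu approximation (all names here are illustrative, not the library's):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// FastGelu, tanh approximation: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
float fast_gelu(float x)
{
    const float u = 0.7978845608f * (x + 0.044715f * x * x * x);
    return 0.5f * x * (1.f + std::tanh(u));
}

// c: MxN GEMM result, d0: length-N bias (broadcast along M, cf. strides {0, 1}),
// d1: MxN addend, e: MxN output
void reference_epilogue(const std::vector<float>& c,
                        const std::vector<float>& d0,
                        const std::vector<float>& d1,
                        std::vector<float>& e,
                        std::size_t M,
                        std::size_t N)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
            e[m * N + n] = fast_gelu(c[m * N + n] + d0[n] + d1[m * N + n]);
}
```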
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp (new file, mode 100644)

```cpp
#pragma once

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v7.hpp"

namespace ck {

// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <typename ThreadGroup,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum DstInMemOp,
          typename SliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename Src0Data,
          typename Src1Data,
          typename Src2Data,
          typename DstData,
          typename Src0Desc,
          typename Src1Desc,
          typename Src2Desc,
          typename DstDesc,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool ThreadTransferSrc0ResetCoordinateAfterRun,
          bool ThreadTransferSrc1ResetCoordinateAfterRun,
          bool ThreadTransferSrc2ResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct ThreadGroupTensorSliceTransfer_v7
{
    static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadGroupTensorSliceTransfer_v7(
        const Src0Desc& src0_desc,
        const Index& src0_block_slice_origin,
        const Src1Desc& src1_desc,
        const Index& src1_block_slice_origin,
        const Src2Desc& src2_desc,
        const Index& src2_block_slice_origin,
        const DstDesc& dst_desc,
        const Index& dst_block_slice_origin,
        const ElementwiseOperation& element_op)
        : threadwise_transfer_(tie(src0_desc, src1_desc, src2_desc),
                               make_tuple(make_zero_multi_index<nDim>(),
                                          make_zero_multi_index<nDim>(),
                                          make_zero_multi_index<nDim>()),
                               tie(dst_desc),
                               make_tuple(make_zero_multi_index<nDim>()),
                               element_op)
    {
        static_assert(nDim == remove_cvref_t<Src0Desc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<Src1Desc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<Src2Desc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(get_thread_local_1d_id()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(
                tie(src0_desc, src1_desc, src2_desc),
                make_tuple(src0_block_slice_origin + thread_data_idx_begin,
                           src1_block_slice_origin + thread_data_idx_begin,
                           src2_block_slice_origin + thread_data_idx_begin));
            threadwise_transfer_.SetDstSliceOrigin(
                tie(dst_desc), make_tuple(dst_block_slice_origin + thread_data_idx_begin));
        }
    }

    template <typename Src0Buffer, typename Src1Buffer, typename Src2Buffer, typename DstBuffer>
    __device__ void Run(const Src0Desc& src0_desc,
                        const Src0Buffer& src0_buf,
                        const Src1Desc& src1_desc,
                        const Src1Buffer& src1_buf,
                        const Src2Desc& src2_desc,
                        const Src2Buffer& src2_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.Run(tie(src0_desc, src1_desc, src2_desc),
                                     tie(src0_buf, src1_buf, src2_buf),
                                     tie(dst_desc),
                                     tie(dst_buf));
        }
    }

    __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
        }
    }

    __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
        }
    }

    __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer = ThreadwiseTensorSliceTransfer_v7<
        Tuple<remove_cvref_t<Src0Data>, remove_cvref_t<Src1Data>, remove_cvref_t<Src2Data>>,
        Tuple<remove_cvref_t<DstData>>,
        Tuple<remove_reference_t<Src0Desc>&,
              remove_reference_t<Src1Desc>&,
              remove_reference_t<Src2Desc>&>,
        Tuple<remove_reference_t<DstDesc>&>,
        ElementwiseOperation,
        decltype(thread_slice_lengths),
        DimAccessOrder,
        VectorDim,
        ScalarPerVector,
        Sequence<ThreadTransferSrc0ResetCoordinateAfterRun,
                 ThreadTransferSrc1ResetCoordinateAfterRun,
                 ThreadTransferSrc2ResetCoordinateAfterRun>,
        Sequence<ThreadTransferDstResetCoordinateAfterRun>,
        DstInMemOp>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
```
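For intuition about the origin computation in the constructor above, here is a minimal standalone sketch of the same arithmetic (a hypothetical 2-D configuration in plain C++, with a row-major cluster standing in for CK's make_cluster_descriptor): SliceLengths {8, 128} over a ThreadClusterLengths {4, 64} cluster gives each thread a 2×2 sub-slice.

```cpp
#include <array>
#include <cstdio>
#include <initializer_list>

int main()
{
    constexpr std::array<int, 2> slice_lengths{8, 128};
    constexpr std::array<int, 2> cluster_lengths{4, 64};

    // thread_slice_lengths = SliceLengths / ThreadClusterLengths
    constexpr std::array<int, 2> thread_slice{slice_lengths[0] / cluster_lengths[0],
                                              slice_lengths[1] / cluster_lengths[1]};

    for(int tid : {0, 1, 64, 255})
    {
        // row-major stand-in for thread_cluster_desc_.CalculateBottomIndex(...)
        const int idx0 = tid / cluster_lengths[1];
        const int idx1 = tid % cluster_lengths[1];

        // thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths
        std::printf("thread %3d -> slice origin {%d, %d}\n",
                    tid, idx0 * thread_slice[0], idx1 * thread_slice[1]);
    }
}
```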
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp

```diff
@@ -558,46 +558,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-#if 1
-            {
-                std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
-                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
-                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
-                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
-
-                std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
-                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
-                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
-                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
-
-                std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
-                          << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
-
-                std::cout << "arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_{ "
-                          << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I0].GetLength(I0) << ", "
-                          << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I0].GetLength(I1) << ", "
-                          << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I0].GetLength(I2) << ", "
-                          << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I0].GetLength(I3) << "}"
-                          << std::endl;
-
-                std::cout << "arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_{ "
-                          << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I1].GetLength(I0) << ", "
-                          << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I1].GetLength(I1) << ", "
-                          << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I1].GetLength(I2) << ", "
-                          << arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[I1].GetLength(I3) << "}"
-                          << std::endl;
-
-                std::cout << "p_ds_grid{ " << arg.p_ds_grid_[I0] << ", " << arg.p_ds_grid_[I1] << "}"
-                          << std::endl;
-            }
-#endif
             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                             arg.b_grid_desc_bk0_n_bk1_,
                                             arg.e_grid_desc_m_n_,
```
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp

```diff
@@ -7,8 +7,8 @@
 #include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
 #include "blockwise_gemm_xdlops.hpp"
 #include "thread_group_tensor_slice_transfer_v4r1.hpp"
-#include "thread_group_tensor_slice_transfer_v6r1.hpp"
 #include "thread_group_tensor_slice_transfer_v6r3.hpp"
+#include "thread_group_tensor_slice_transfer_v7.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "gridwise_gemm_pipeline_v1.hpp"
@@ -223,7 +223,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
         return e_grid_desc_mblock_mperblock_nblock_nperblock;
     }

-    // return block_id to C matrix tile idx (m0, n0) mapping
+    // return block_id to E matrix tile idx (m0, n0) mapping
     __host__ __device__ static constexpr auto
     MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
     {
@@ -579,7 +579,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
             make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
             cde_element_op};
 #else
-        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v6r1<
+        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
             ThisThreadBlock,            // ThreadGroup
             CDEElementwiseOperation,    // ElementwiseOperation,
             EGlobalMemoryDataOperation, // DstInMemOp,
@@ -588,18 +588,28 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
                      1,
                      CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
             CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
             Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
             FloatCShuffle,        // typename Src0Data,
-            FloatE,               // typename DstData,
+            remove_cvref_t<decltype(DsDataType{}[I0])>, // typename Src1Data,
+            remove_cvref_t<decltype(DsDataType{}[I1])>, // typename Src2Data,
+            FloatE,                                     // typename DstData,
             decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
+            decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I0]),
+            decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I1]),
             decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
             Sequence<0, 1, 2, 3>,                             // typename DimAccessOrder,
             3,                                                // index_t VectorDim,
             CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
             true,  // bool ThreadTransferSrc0ResetCoordinateAfterRun,
+            false, // bool ThreadTransferSrc1ResetCoordinateAfterRun,
+            false, // bool ThreadTransferSrc2ResetCoordinateAfterRun,
             false> // bool ThreadTransferDstResetCoordinateAfterRun>
             {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
              make_multi_index(0, 0, 0, 0),
+             ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
+             make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
+             ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
+             make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
              e_grid_desc_mblock_mperblock_nblock_nperblock,
              make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
              cde_element_op};
@@ -660,6 +670,10 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
             cde_block_copy_lds_and_global.Run(
                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                 c_shuffle_block_buf,
+                ds_grid_desc_mblock_mperblock_nblock_nperblock[I0],
+                ds_grid_buf[I0],
+                ds_grid_desc_mblock_mperblock_nblock_nperblock[I1],
+                ds_grid_buf[I1],
                 e_grid_desc_mblock_mperblock_nblock_nperblock,
                 e_grid_buf);
 #endif
```
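For intuition, MakeDefaultBlock2ETileMap is what turns the 1-D grid into (m0, n0) output-tile coordinates; with M = 3840, N = 4096 and the 256×128 block tile from the README, that is 15 × 32 = 480 tiles, matching grid_dim {480, 1, 1} in the log. A deliberately simplified, hypothetical row-major sketch of such a map (CK's real map is a descriptor-based adaptor, not this struct):

```cpp
#include <utility>

// Hypothetical block-to-E-tile map: grid block b covers output tile {m0, n0},
// with n_tile_count = N / NPerBlock tiles along N.
struct SimpleBlock2ETileMap
{
    int n_tile_count;

    std::pair<int, int> CalculateTileIdx(int block_id) const
    {
        return {block_id / n_tile_count, block_id % n_tile_count}; // {m0, n0}
    }
};

// Example: M=3840, N=4096, MPerBlock=256, NPerBlock=128
// -> 15 x 32 = 480 tiles, one per block in grid_dim {480, 1, 1}.
```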
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp

```diff
@@ -28,12 +28,14 @@ template <typename SrcDatas,
           typename DimAccessOrder,
           index_t VectorDim,
           index_t ScalarPerVector,
-          bool SrcResetCoordinateAfterRun,
-          bool DstResetCoordinateAfterRun,
+          typename SrcResetCoordinateAfterRunFlags, // Sequence<...>
+          typename DstResetCoordinateAfterRunFlags, // Sequence<...>
           InMemoryDataOperationEnum... DstInMemOps>
 struct ThreadwiseTensorSliceTransfer_v7
 {
     static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};

     static constexpr index_t nDim = SliceLengths::Size();
@@ -46,7 +48,7 @@ struct ThreadwiseTensorSliceTransfer_v7
     template <typename Descs,
               typename Indices,
               enable_if_t<Descs::Size() == Indices::Size(), bool> = false>
-    constexpr auto MakeCoordiantes(const Descs& descs, const Indices& indices)
+    static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices)
     {
         return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); },
                               Number<Descs::Size()>{});
@@ -95,22 +97,28 @@ struct ThreadwiseTensorSliceTransfer_v7
         });
     }

+    // SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
+    // SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
+    // DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
+    // DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
     template <typename SrcBuffers,
               typename DstBuffers,
-              enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
-                              DstDescs::Size() == DstBuffers::Size()>,
-                          bool = false>
+              enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
+                              DstDescs::Size() == DstBuffers::Size(),
+                          bool> = false>
     __device__ void Run(const SrcDescs& src_descs,
                         const SrcBuffers& src_bufs,
                         const DstDescs& dst_descs,
-                        DstBuffers& dst_bufs)
+                        const DstBuffers& dst_bufs)
     {
         auto generate_vectors = [&](auto data_types) {
+            constexpr index_t num = data_types.Size();
+
             return generate_tuple(
                 [&](auto i) {
-                    using DataType = decltype(data_types[i]);
+                    using DataType = remove_cvref_t<decltype(data_types[i])>;

                     return vector_type_maker_t<DataType, ScalarPerVector>{};
-                });
+                },
+                Number<num>{});
         };

         constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
@@ -122,7 +130,7 @@ struct ThreadwiseTensorSliceTransfer_v7
             // copy data from src_bufs into src_vectors
             static_for<0, nSrc, 1>{}([&](auto i) {
-                using src_vector_t = typename remove_cv_t<decltype(src_vectors[i])>::type;
+                using src_vector_t = remove_cvref_t<typename decltype(src_vectors[i])::type>;

                 const bool is_src_valid =
                     coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
@@ -135,11 +143,16 @@ struct ThreadwiseTensorSliceTransfer_v7
             // apply pointwise function
             // FIXME: support tuple of arbitary size
             static_for<0, ScalarPerVector, 1>{}([&](auto i) {
-                using SrcData0 = decltype(SrcDatas{}.At[I0]);
-                using DstData0 = decltype(DstDatas{}.At[I0]);
+                using SrcData0 = remove_cvref_t<decltype(SrcDatas{}[I0])>;
+                using SrcData1 = remove_cvref_t<decltype(SrcDatas{}[I1])>;
+                using SrcData2 = remove_cvref_t<decltype(SrcDatas{}[I2])>;
+
+                using DstData0 = remove_cvref_t<decltype(DstDatas{}[I0])>;

                 element_op_(dst_vectors[I0].template AsType<DstData0>()(i),
-                            src_vectors[I0].template AsType<SrcData0>()[i]);
+                            src_vectors[I0].template AsType<SrcData0>()[i],
+                            src_vectors[I1].template AsType<SrcData1>()[i],
+                            src_vectors[I2].template AsType<SrcData2>()[i]);
             });

             // copy data from buf_vectors into dst_bufs
@@ -178,25 +191,25 @@ struct ThreadwiseTensorSliceTransfer_v7
         });

         // move coordinate back to slice origin (or not)
-        if constexpr(SrcResetCoordinateAfterRun)
-        {
-            static_for<0, nSrc, 1>{}([&](auto i) {
+        static_for<0, nSrc, 1>{}([&](auto i) {
+            if constexpr(SrcResetCoordinateAfterRunFlags::At(i))
+            {
                 const auto src_reset_step =
                     make_tensor_coordinate_step(src_descs[i], GetCoordinateResetStep());

                 move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step);
-            });
-        }
+            }
+        });

-        if constexpr(DstResetCoordinateAfterRun)
-        {
-            static_for<0, nDst, 1>{}([&](auto i) {
+        static_for<0, nDst, 1>{}([&](auto i) {
+            if constexpr(DstResetCoordinateAfterRunFlags::At(i))
+            {
                 const auto dst_reset_step =
                     make_tensor_coordinate_step(dst_descs[i], GetCoordinateResetStep());

                 move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step);
-            });
-        }
+            }
+        });
     }

     __device__ static constexpr auto GetCoordinateResetStep()
@@ -220,12 +233,13 @@ struct ThreadwiseTensorSliceTransfer_v7
     __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
                                        const Index& src_slice_origin_step_idx)
     {
-        // if src coord was not reset by RunRead(), then need to adjust the step here
-        const auto adjusted_step_idx =
-            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
-                                       : src_slice_origin_step_idx + GetCoordinateResetStep();
-
         static_for<0, nSrc, 1>{}([&](auto i) {
+            // if src coord was not reset by RunRead(), then need to adjust the step here
+            const auto adjusted_step_idx = SrcResetCoordinateAfterRunFlags::At(i)
+                                               ? src_slice_origin_step_idx
+                                               : src_slice_origin_step_idx + GetCoordinateResetStep();
+
             // is it OK to construct a new step every time?
             const auto adjusted_step = make_tensor_coordinate_step(src_descs[i], adjusted_step_idx);
@@ -237,12 +251,13 @@ struct ThreadwiseTensorSliceTransfer_v7
     __device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
                                        const Index& dst_slice_origin_step_idx)
     {
-        // if dst coord was not reset by Run(), then need to adjust the step here
-        const auto adjusted_step_idx =
-            DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
-                                       : dst_slice_origin_step_idx + GetCoordinateResetStep();
-
         static_for<0, nDst, 1>{}([&](auto i) {
+            // if dst coord was not reset by Run(), then need to adjust the step here
+            const auto adjusted_step_idx = DstResetCoordinateAfterRunFlags::At(i)
+                                               ? dst_slice_origin_step_idx
+                                               : dst_slice_origin_step_idx + GetCoordinateResetStep();
+
             // is it OK to construct a new step every time?
             const auto adjusted_step = make_tensor_coordinate_step(dst_descs[i], adjusted_step_idx);
```
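The key interface change here is that the single SrcResetCoordinateAfterRun/DstResetCoordinateAfterRun bools become one compile-time flag per source/destination tensor, so e.g. the LDS source can snap back to its slice origin after every Run() while the global D-tensor coordinates keep walking. A minimal stand-alone sketch of the flag-sequence pattern (a hypothetical FlagSeq in place of ck::Sequence):

```cpp
#include <cstdio>

// Minimal stand-in for the Sequence<bool...>::At(i) dispatch used above.
template <bool... Flags>
struct FlagSeq
{
    static constexpr bool At(int i)
    {
        constexpr bool flags[] = {Flags...};
        return flags[i];
    }
};

int main()
{
    // e.g. Sequence<true, false, false>: only src0 (LDS) resets after Run()
    using SrcResetFlags = FlagSeq<true, false, false>;

    for(int i = 0; i < 3; ++i)
        std::printf("src%d resets after Run(): %s\n", i, SrcResetFlags::At(i) ? "yes" : "no");
}
```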
include/ck/utility/amd_buffer_addressing.hpp

```diff
@@ -6,6 +6,8 @@ namespace ck {
 template <typename T>
 union BufferResource
 {
+    __device__ constexpr BufferResource() : content{} {}
+
     // 128 bit SGPRs to supply buffer resource in buffer instructions
     // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
     int32x4_t content;
```
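The two added lines give BufferResource an explicit constexpr default constructor that value-initializes content, presumably so a default-constructed resource descriptor starts from defined bits rather than indeterminate contents. A standalone sketch of the pattern (hypothetical stand-in types, not the CK header):

```cpp
// Hypothetical stand-in for CK's int32x4_t vector type
struct int32x4_demo
{
    int data[4];
};

union BufferResourceDemo
{
    // value-initialize the active member so default construction yields
    // defined contents and remains usable in constexpr contexts
    constexpr BufferResourceDemo() : content{} {}

    int32x4_demo content;
    int range[4];
};

constexpr BufferResourceDemo res{}; // res.content.data is all zeros
```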
include/ck/utility/tuple.hpp

```diff
@@ -17,14 +17,18 @@ struct TupleElementKey
 };

 template <typename Key, typename Data>
-struct TupleElement
+struct TupleElementKeyData
 {
-    __host__ __device__ constexpr TupleElement() = default;
+#if 0
+    __host__ __device__ constexpr TupleElementKeyData() = default;
+#else
+    __host__ __device__ constexpr TupleElementKeyData() : mData{} {}
+#endif

     template <typename T,
-              typename enable_if<!is_same<remove_cvref_t<T>, TupleElement>::value, bool>::type = false>
-    __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
+              typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value, bool>::type = false>
+    __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward<T>(v))
     {
     }
@@ -32,20 +36,20 @@ struct TupleElement
 };

 template <typename Key, typename Data>
-__host__ __device__ constexpr const Data& get_tuple_element(const TupleElement<Key, Data>& x)
+__host__ __device__ constexpr const Data& get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
 {
     return static_cast<const Data&>(x.mData);
 }

 template <typename Key, typename Data>
-__host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x)
+__host__ __device__ constexpr Data& get_tuple_element_data(TupleElementKeyData<Key, Data>& x)
 {
     return x.mData;
 }

 // TODO: not sure the use of reference is correct
 template <typename Key, typename Data>
-__host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x)
+__host__ __device__ constexpr Data&& get_tuple_element_data(TupleElementKeyData<Key, Data>&& x)
 {
     return static_cast<Data&&>(x.mData);
 }
@@ -54,7 +58,7 @@ template <typename Indices, typename... Xs>
 struct TupleImpl;

 template <index_t... Is, typename... Xs>
-struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>...
+struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<Is>, Xs>...
 {
     __host__ __device__ constexpr TupleImpl() = default;
@@ -63,13 +67,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
                   !is_same<remove_cvref_t<Y>, TupleImpl>::value,
                   bool>::type = false>
     __host__ __device__ constexpr TupleImpl(Y&& y)
-        : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
+        : TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
     {
     }

     template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
     __host__ __device__ constexpr TupleImpl(Ys&&... ys)
-        : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
+        : TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
     {
         static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
                       "wrong! inconsistent size");
@@ -78,15 +82,15 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
     __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }

     template <index_t I>
-    __host__ __device__ constexpr const auto& GetElementByKey(TupleElementKey<I>) const
+    __host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey<I>) const
     {
-        return get_tuple_element<TupleElementKey<I>>(*this);
+        return get_tuple_element_data<TupleElementKey<I>>(*this);
     }

     template <index_t I>
-    __host__ __device__ constexpr auto& GetElementByKey(TupleElementKey<I>)
+    __host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey<I>)
     {
-        return get_tuple_element<TupleElementKey<I>>(*this);
+        return get_tuple_element_data<TupleElementKey<I>>(*this);
     }
 };
@@ -121,7 +125,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
     __host__ __device__ constexpr const auto& At(Number<I>) const
     {
         static_assert(I < base::Size(), "wrong! out of range");
-        return base::GetElementByKey(detail::TupleElementKey<I>{});
+        return base::GetElementDataByKey(detail::TupleElementKey<I>{});
     }

     // write access
@@ -129,7 +133,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
     __host__ __device__ constexpr auto& At(Number<I>)
     {
         static_assert(I < base::Size(), "wrong! out of range");
-        return base::GetElementByKey(detail::TupleElementKey<I>{});
+        return base::GetElementDataByKey(detail::TupleElementKey<I>{});
     }

     // read access
@@ -159,6 +163,31 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
     __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
 };

+template <>
+struct Tuple<>
+{
+    __host__ __device__ constexpr Tuple() = default;
+
+    __host__ __device__ static constexpr index_t Size() { return 0; }
+
+    template <typename T>
+    __host__ __device__ constexpr auto operator=(const T&)
+    {
+        return *this;
+    }
+
+    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
+};
+
+template <index_t I, typename TTuple>
+struct tuple_element
+{
+    using type = decltype(TTuple{}.At(Number<I>{}));
+};
+
+template <index_t I, typename TTuple>
+using tuple_element_t = typename tuple_element<I, TTuple>::type;
+
 template <typename... Xs>
 __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
 {
```
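The renames here (TupleElement → TupleElementKeyData, get_tuple_element → get_tuple_element_data, GetElementByKey → GetElementDataByKey) leave the underlying technique intact: each element lives in a base class tagged with a unique key type, and overload resolution on the key picks the right base subobject. A minimal stand-alone version of the pattern (C++17, simplified; hypothetical names, no __host__ __device__):

```cpp
#include <utility>

template <int I>
struct Key
{
};

template <typename K, typename Data>
struct ElementKeyData
{
    Data mData;
};

// overload resolution on the key type selects the matching base subobject
template <typename K, typename Data>
constexpr const Data& get_element_data(const ElementKeyData<K, Data>& x)
{
    return x.mData;
}

template <typename Seq, typename... Xs>
struct MiniTupleImpl;

template <int... Is, typename... Xs>
struct MiniTupleImpl<std::integer_sequence<int, Is...>, Xs...> : ElementKeyData<Key<Is>, Xs>...
{
    template <int I>
    constexpr const auto& At() const
    {
        return get_element_data<Key<I>>(*this);
    }
};

template <typename... Xs>
using MiniTuple = MiniTupleImpl<std::make_integer_sequence<int, sizeof...(Xs)>, Xs...>;

static_assert(MiniTuple<int, double>{{1}, {2.5}}.At<0>() == 1);
```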