Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b097be17
Commit
b097be17
authored
Jun 23, 2022
by
root
Browse files
merge changes for upstream/latest update
parents
8a891bbd
a49115b9
Changes
140
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1850 additions
and
95 deletions
+1850
-95
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+668
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
...eration/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
+6
-5
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
...or_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
+3
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp
...k/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp
include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp
+407
-0
include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp
...nsor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp
+129
-0
include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp
...r_operation/gpu/thread/reduction_functions_threadwise.hpp
+12
-12
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp
...ration/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp
+295
-0
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+2
-0
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+9
-0
include/ck/utility/enable_if.hpp
include/ck/utility/enable_if.hpp
+1
-3
include/ck/utility/math.hpp
include/ck/utility/math.hpp
+16
-0
include/ck/utility/reduction_functions_accumulate.hpp
include/ck/utility/reduction_functions_accumulate.hpp
+19
-0
include/ck/utility/reduction_operator.hpp
include/ck/utility/reduction_operator.hpp
+120
-27
include/ck/utility/sequence.hpp
include/ck/utility/sequence.hpp
+14
-4
include/ck/utility/tuple.hpp
include/ck/utility/tuple.hpp
+49
-21
include/ck/utility/tuple_helper.hpp
include/ck/utility/tuple_helper.hpp
+12
-3
library/include/ck/library/host_tensor/host_reduction.hpp
library/include/ck/library/host_tensor/host_reduction.hpp
+19
-14
library/include/ck/library/host_tensor/host_tensor.hpp
library/include/ck/library/host_tensor/host_tensor.hpp
+66
-2
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
0 → 100644
View file @
b097be17
#pragma once
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v7.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
namespace
ck
{
// input : A[AK0, M, AK1]
// input : B[AK0, N, AK1]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
template
<
typename
FloatAB
,
typename
FloatGemmAcc
,
typename
FloatCShuffle
,
typename
DsDataType
,
typename
FloatE
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
InMemoryDataOperationEnum
EGlobalMemoryDataOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
EGridDesc_M_N
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1Value
,
index_t
BK1Value
,
index_t
MPerXdl
,
index_t
NPerXdl
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
>
struct
GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
// K1 should be Number<...>
static
constexpr
auto
AK0
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
GridwiseGemmPipeline_v1
<
NumGemmKPrefetchStage
>
;
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
{
// A matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
AK0
,
Number
<
MPerBlock
>
{},
AK1
),
make_tuple
(
Number
<
MPerBlock
+
ABlockLdsExtraM
>
{}
*
AK1
,
AK1
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
// B matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
BK0
,
Number
<
NPerBlock
>
{},
BK1
),
make_tuple
(
Number
<
NPerBlock
+
BBlockLdsExtraN
>
{}
*
BK1
,
BK1
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
>
{},
I1
,
Number
<
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
{}));
return
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
;
}
// ck::Tuple<const D0DataType*, const D1DataType*, ...>
static
constexpr
auto
MakeDsGridPointer
()
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
return
static_cast
<
const
DDataType
*>
(
nullptr
);
},
Number
<
NumDTensor
>
{});
}
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
b_block_space_size_aligned
=
math
::
integer_least_multiple
(
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
(),
max_lds_align
);
// LDS allocation for C shuffle in LDS
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
constexpr
auto
c_block_size
=
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
return
math
::
max
((
a_block_space_size_aligned
+
b_block_space_size_aligned
)
*
sizeof
(
FloatAB
),
c_block_size
*
sizeof
(
FloatCShuffle
));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2ETileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
Block2ETileMap
&
block_2_etile_map
)
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
);
const
auto
K
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)))
return
false
;
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
))
return
false
;
// check gridwise gemm pipeline
const
auto
num_k_loop
=
K
/
KPerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
{
return
false
;
}
if
(
!
block_2_etile_map
.
CheckValidity
(
e_grid_desc_m_n
))
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
KPerBlock
;
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
__host__
__device__
static
constexpr
auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
EGridDesc_M_N
&
e_grid_desc_m_n
)
{
const
auto
M
=
e_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
e_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
const
auto
e_grid_desc_mblock_mperblock_nblock_nperblock
=
transform_tensor_descriptor
(
e_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{})),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
return
e_grid_desc_mblock_mperblock_nblock_nperblock
;
}
// return block_id to E matrix tile idx (m0, n0) mapping
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2ETileMap
(
const
EGridDesc_M_N
&
e_grid_desc_m_n
)
{
return
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
EGridDesc_M_N
>
(
e_grid_desc_m_n
);
}
using
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
using
DefaultBlock2ETileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
using
DsGridPointer
=
decltype
(
MakeDsGridPointer
());
template
<
bool
HasMainKBlockLoop
,
typename
Block2ETileMap
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
DsGridPointer
p_ds_grid
,
FloatE
*
__restrict__
p_e_grid
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
StaticallyIndexedArray
<
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>&
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
// FIXME: Ds desc may be of different
// type from E
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2ETileMap
&
block_2_etile_map
)
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
const
auto
ds_grid_buf
=
generate_tuple
(
[
&
](
auto
i
)
{
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_ds_grid
[
i
],
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
i
].
GetElementSpaceSize
());
},
Number
<
NumDTensor
>
{});
auto
e_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_e_grid
,
e_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
// divide block work by [M, N]
const
auto
block_work_idx
=
block_2_etile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_etile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
e_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I0
),
e_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I2
))))
{
return
;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
NPerBlock
);
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1
,
BK1
);
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
// A matrix blockwise copy
auto
a_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
decltype
(
a_grid_desc_ak0_m_ak1
),
decltype
(
a_block_desc_ak0_m_ak1
),
ABlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
ABlockTransferSrcVectorDim
,
2
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
,
NumGemmKPrefetchStage
>
(
a_grid_desc_ak0_m_ak1
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_element_op
,
a_block_desc_ak0_m_ak1
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// B matrix blockwise copy
auto
b_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
BBlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
BBlockTransferSrcVectorDim
,
2
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
,
NumGemmKPrefetchStage
>
(
b_grid_desc_bk0_n_bk1
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b_element_op
,
b_block_desc_bk0_n_bk1
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
FloatAB
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
MPerXdl
,
NPerXdl
,
MXdlPerWave
,
NXdlPerWave
,
KPack
,
LoopSched
>
();
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatAB
*>
(
p_shared
)
+
a_block_space_size_aligned
,
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
BK1
,
0
,
0
);
// gridwise GEMM pipeline
const
auto
gridwise_gemm_pipeline
=
GridwiseGemmPipeline_v1_Selector
<
NumGemmKPrefetchStage
,
LoopSched
>
();
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))
/
KPerBlock
);
gridwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
>(
a_grid_desc_ak0_m_ak1
,
a_block_desc_ak0_m_ak1
,
a_blockwise_copy
,
a_grid_buf
,
a_block_buf
,
a_block_slice_copy_step
,
b_grid_desc_bk0_n_bk1
,
b_block_desc_bk0_n_bk1
,
b_blockwise_copy
,
b_grid_buf
,
b_block_buf
,
b_block_slice_copy_step
,
blockwise_gemm
,
c_thread_buf
,
num_k_block_main_loop
);
// shuffle C and write out
{
static_assert
(
MXdlPerWave
%
CShuffleMXdlPerWavePerShuffle
==
0
&&
NXdlPerWave
%
CShuffleNXdlPerWavePerShuffle
==
0
,
"wrong!"
);
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
// TODO: hacky, fix it!
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
=
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
M0
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I4
);
constexpr
auto
M3
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I5
);
constexpr
auto
M4
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I7
);
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
auto
c_shuffle_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatCShuffle
*>
(
p_shared
),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_tuple
(
make_freeze_transform
(
I0
),
make_unmerge_transform
(
make_tuple
(
Number
<
CShuffleMXdlPerWavePerShuffle
>
{},
// M0 (MXdlPerWave) per shuffle
M1
,
// M1 = MWave
M2
,
// M2 * M3 * M4 = MPerXdl
M3
,
M4
)),
make_freeze_transform
(
I0
),
make_unmerge_transform
(
make_tuple
(
Number
<
CShuffleNXdlPerWavePerShuffle
>
{},
// N0 (NXdlPerWave) per shuffle
N1
,
// N1 = NWave
N2
))),
// N2 = NPerXdl
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<>
{},
Sequence
<
1
,
3
,
7
>
{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
blockwise_gemm
.
CalculateCThreadOriginDataIndex
(
I0
,
I0
,
I0
,
I0
);
const
index_t
m_thread_data_on_block
=
c_thread_mtx_on_block
[
I0
];
const
index_t
n_thread_data_on_block
=
c_thread_mtx_on_block
[
I1
];
const
auto
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
M1
,
M2
,
M3
,
M4
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
m_thread_data_on_block_idx
=
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
m_thread_data_on_block
));
const
auto
n_thread_data_on_block_to_n0_n1_n2_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N0
,
N1
,
N2
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
n_thread_data_on_block_idx
=
n_thread_data_on_block_to_n0_n1_n2_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
n_thread_data_on_block
));
// shuffle: threadwise copy C from VGPR to LDS
auto
c_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
FloatCShuffle
,
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_multi_index
(
0
,
0
,
m_thread_data_on_block_idx
[
I1
],
n_thread_data_on_block_idx
[
I1
],
m_thread_data_on_block_idx
[
I2
],
m_thread_data_on_block_idx
[
I3
],
m_thread_data_on_block_idx
[
I4
],
n_thread_data_on_block_idx
[
I2
]),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
// tuple of reference to C/Ds tensor descriptors
const
auto
c_ds_desc_refs
=
concat_tuple_of_reference
(
tie
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
generate_tie
(
[
&
](
auto
i
)
->
const
auto
&
// return type should be reference
{
return
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
i
];
},
Number
<
NumDTensor
>
{}));
// tuple of reference to C/Ds tensor descriptors
const
auto
c_ds_buf_refs
=
concat_tuple_of_reference
(
tie
(
c_shuffle_block_buf
),
generate_tie
(
[
&
](
auto
i
)
->
const
auto
&
// return type should be reference
{
return
ds_grid_buf
[
i
];
},
Number
<
NumDTensor
>
{}));
// tuple of starting index of C/Ds blockwise copy
const
auto
idx_c_ds_block_begin
=
container_concat
(
make_tuple
(
make_multi_index
(
0
,
0
,
0
,
0
)),
generate_tuple
(
[
&
](
auto
)
{
return
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
);
},
Number
<
NumDTensor
>
{}));
// blockwise copy C/D/E between LDS and global
auto
cde_block_copy_lds_and_global
=
ThreadGroupTensorSliceTransfer_v7
<
ThisThreadBlock
,
decltype
(
container_concat
(
make_tuple
(
FloatCShuffle
{}),
DsDataType
{})),
Tuple
<
FloatE
>
,
decltype
(
c_ds_desc_refs
),
decltype
(
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
)),
CDEElementwiseOperation
,
Sequence
<
static_cast
<
index_t
>
(
EGlobalMemoryDataOperation
)
>
,
// FIXME: make Sequence
// support arbitray type
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
3
,
// index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock
,
sequence_merge_t
<
Sequence
<
true
>
,
uniform_sequence_gen_t
<
NumDTensor
,
false
>>
,
// ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence
<
false
>>
// ThreadTransferDstResetCoordinateAfterRunFlags
{
c_ds_desc_refs
,
idx_c_ds_block_begin
,
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
make_tuple
(
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
)),
cde_element_op
};
// space filling curve for threadwise C in VGPR before shuffle
constexpr
auto
sfc_c_vgpr
=
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
Sequence
<
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
1
,
1
,
M2
,
1
,
M4
,
1
>>
{};
// space filling curve for shuffled blockwise C/D/E
constexpr
auto
sfc_cde_block
=
SpaceFillingCurve
<
Sequence
<
1
,
MPerBlock
,
1
,
NPerBlock
>
,
Sequence
<
0
,
2
,
1
,
3
>
,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>>
{};
constexpr
index_t
num_access
=
sfc_c_vgpr
.
GetNumOfAccess
();
static_assert
(
num_access
==
sfc_cde_block
.
GetNumOfAccess
(),
"wrong!"
);
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
// make sure it's safe to write to LDS
block_sync_lds
();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
sfc_c_vgpr
.
GetIndexTupleOfNumber
(
access_id
),
c_thread_buf
,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_shuffle_block_buf
);
// make sure it's safe to read from LDS
block_sync_lds
();
// each block copy its data from LDS to global
cde_block_copy_lds_and_global
.
Run
(
c_ds_desc_refs
,
c_ds_buf_refs
,
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
tie
(
e_grid_buf
));
if
constexpr
(
access_id
<
num_access
-
1
)
{
constexpr
auto
cde_lds_and_global_step
=
sfc_cde_block
.
GetForwardStep
(
access_id
);
// move on Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
cde_block_copy_lds_and_global
.
MoveSrcSliceWindow
(
c_ds_desc_refs
,
i
+
I1
,
cde_lds_and_global_step
);
});
// move on E
cde_block_copy_lds_and_global
.
MoveDstSliceWindow
(
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
I0
,
cde_lds_and_global_step
);
}
});
}
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
View file @
b097be17
...
...
@@ -21,7 +21,7 @@ template <typename GridwiseGemm,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
,
typename
Dxs
Reduce
AccElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
...
@@ -41,7 +41,7 @@ __global__ void
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
DxsInElementwiseOperation
dxs_in_element_op
,
const
DxsAccElementwiseOperation
dxs_out_element_op
,
const
Dxs
Reduce
AccElementwiseOperation
dxs_out_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
...
@@ -96,7 +96,7 @@ template <typename FloatAB,
typename
CElementwiseOperation
,
typename
DxsReduceOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsAccElementwiseOperation
,
typename
Dxs
Reduce
AccElementwiseOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
DGlobalMemoryDataOperation
,
typename
AGridDesc_AK0_M_AK1
,
...
...
@@ -329,7 +329,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
DxsInElementwiseOperation
&
dxs_in_element_op
,
const
DxsAccElementwiseOperation
&
dxs_out_element_op
,
const
Dxs
Reduce
AccElementwiseOperation
&
dxs_out_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
...
...
@@ -816,7 +816,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
false
>
;
// Global write Gemm shuffle + reduction
const
auto
d_identityVal
=
DReduceOperation
::
GetIdentityValue
();
const
auto
d_identityVal
=
DReduceOperation
::
template
GetIdentityValue
<
FloatReduceAcc
>();
static_for
<
0
,
mreduce_per_thread
,
1
>
{}(
[
&
](
auto
I
)
{
d_thread_buf
(
I
)
=
d_identityVal
;
});
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
View file @
b097be17
...
...
@@ -791,8 +791,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr
auto
c_block_desc_mblock_mperblock_nblock_nperblock
=
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
void
*
p_shared
=
static_cast
<
void
*>
(
p_shared_block
);
auto
c_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatC
*>
(
p_shared
_block
),
static_cast
<
FloatC
*>
(
p_shared
),
c_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
static_assert
(
M1
==
MWave
,
""
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
View file @
b097be17
...
...
@@ -249,7 +249,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
}();
using
BlockwiseGemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
ThisThread
Block
,
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
Block
Size
FloatAB
,
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
...
...
@@ -453,7 +453,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
// sanity check
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
ThisThread
Block
,
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
Block
Size
,
FloatAB
,
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
...
...
include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp
View file @
b097be17
...
...
@@ -37,7 +37,7 @@ __global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffe
{
using
PassThroughOp
=
tensor_operation
::
element_wise
::
UnaryIdentic
<
DataType
,
DataType
>
;
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
constexpr
auto
I0
=
Number
<
0
>
{};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp
0 → 100644
View file @
b097be17
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2022 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GRIDWISE_SOFTMAX_HPP
#define GRIDWISE_SOFTMAX_HPP
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
#include "reduction_functions_blockwise.hpp"
#include "reduction_functions_threadwise.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
template
<
typename
GridwiseReduction
,
typename
InDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
GridDesc_M_K
>
__global__
void
kernel_softmax
(
const
GridDesc_M_K
in_grid_desc_m_k
,
const
GridDesc_M_K
out_grid_desc_m_k
,
index_t
block_group_size
,
index_t
num_k_block_tile_iteration
,
AccDataType
alpha
,
const
InDataType
*
const
__restrict__
p_in_value_global
,
AccDataType
beta
,
OutDataType
*
const
__restrict__
p_out_value_global
)
{
GridwiseReduction
::
Run
(
in_grid_desc_m_k
,
out_grid_desc_m_k
,
block_group_size
,
num_k_block_tile_iteration
,
alpha
,
p_in_value_global
,
beta
,
p_out_value_global
);
};
template
<
typename
InDataType
,
typename
OutDataType
,
typename
AccDataType
,
typename
GridDesc_M_K
,
index_t
BlockSize
,
index_t
MThreadClusterSize
,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
InSrcVectorDim
,
index_t
InSrcVectorSize
,
index_t
OutDstVectorSize
>
struct
GridwiseSoftmax_mk_to_mk
{
static_assert
(((
InSrcVectorDim
==
0
&&
MThreadSliceSize
%
InSrcVectorSize
==
0
)
||
(
InSrcVectorDim
==
1
&&
KThreadSliceSize
%
InSrcVectorSize
==
0
))
&&
(
KThreadSliceSize
%
OutDstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static
constexpr
bool
reorder_thread_cluster
=
(
InSrcVectorDim
==
0
);
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
using
ThreadBufferDimAccessOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
using
ThreadClusterArrangeOrder
=
typename
conditional
<
reorder_thread_cluster
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>>::
type
;
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{})));
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{})));
using
BlockwiseMaxReduce
=
PartitionedBlockwiseReduction
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
reduce
::
Max
,
false
>
;
// PropagateNan
using
ThreadwiseMaxReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
reduce
::
Max
,
false
>
;
// PropagateNan
using
PassThroughOp
=
tensor_operation
::
element_wise
::
PassThrough
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
__device__
static
void
Run
(
const
GridDesc_M_K
&
in_grid_desc_m_k
,
const
GridDesc_M_K
&
out_grid_desc_m_k
,
index_t
block_group_size
,
index_t
num_k_block_tile_iteration
,
AccDataType
alpha
,
const
InDataType
*
const
__restrict__
p_in_value_global
,
AccDataType
beta
,
OutDataType
*
const
__restrict__
p_out_value_global
)
{
// LDS
__shared__
AccDataType
p_reduce_work_buffer
[
BlockSize
];
auto
out_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_out_value_global
,
out_grid_desc_m_k
.
GetElementSpaceSize
());
auto
reduce_work_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_work_buffer
,
BlockSize
);
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
in_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
out_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
max_value_buf
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
max_value_buf
(
I
)
=
reduce
::
Max
::
template
GetIdentityValue
<
AccDataType
>();
});
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
accu_value_buf
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
=
reduce
::
Add
::
template
GetIdentityValue
<
AccDataType
>();
});
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
blkgroup_id
=
block_global_id
/
block_group_size
;
const
index_t
block_local_id
=
block_global_id
%
block_group_size
;
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
I1
];
const
index_t
reduceSizePerBlock
=
K_BlockTileSize
*
num_k_block_tile_iteration
;
using
ThreadBufferLengths
=
Sequence
<
MThreadSliceSize
,
KThreadSliceSize
>
;
constexpr
auto
thread_buffer_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MThreadSliceSize
>
{},
Number
<
KThreadSliceSize
>
{}));
auto
threadwise_src_load
=
ThreadwiseTensorSliceTransfer_v2
<
InDataType
,
AccDataType
,
GridDesc_M_K
,
decltype
(
thread_buffer_desc
),
ThreadBufferLengths
,
ThreadBufferDimAccessOrder
,
InSrcVectorDim
,
InSrcVectorSize
,
1
,
false
>
(
in_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_dst_load
=
ThreadwiseTensorSliceTransfer_v2
<
OutDataType
,
AccDataType
,
GridDesc_M_K
,
decltype
(
thread_buffer_desc
),
ThreadBufferLengths
,
ThreadBufferDimAccessOrder
,
InSrcVectorDim
,
InSrcVectorSize
,
1
,
false
>
(
out_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
KThreadSliceSize
));
auto
threadwise_dst_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
OutDataType
,
decltype
(
thread_buffer_desc
),
GridDesc_M_K
,
PassThroughOp
,
ThreadBufferLengths
,
ThreadBufferDimAccessOrder
,
InSrcVectorDim
,
OutDstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
out_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
KThreadSliceSize
),
PassThroughOp
{});
constexpr
auto
in_thread_copy_fwd_step
=
make_multi_index
(
0
,
K_BlockTileSize
);
constexpr
auto
in_thread_copy_bwd_step
=
make_multi_index
(
0
,
-
K_BlockTileSize
);
///
/// max(x)
///
const
auto
in_global_val_buf_oob_non_zero
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_value_global
,
in_grid_desc_m_k
.
GetElementSpaceSize
(),
reduce
::
Max
::
template
GetIdentityValue
<
InDataType
>());
index_t
reducedTiles
=
0
;
do
{
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
in_global_val_buf_oob_non_zero
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_buf
);
ThreadwiseMaxReduce
::
Reduce
(
in_thread_buf
,
max_value_buf
);
threadwise_src_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_fwd_step
);
reducedTiles
++
;
}
while
(
reducedTiles
<
num_k_block_tile_iteration
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}(
[
&
](
auto
I
)
{
BlockwiseMaxReduce
::
Reduce
(
reduce_work_buf
,
max_value_buf
(
I
));
});
threadwise_src_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_bwd_step
);
///
/// sum(exp(x - max(x)))
///
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
=
reduce
::
Add
::
template
GetIdentityValue
<
AccDataType
>();
});
// Normally, 0 as invalid element value is adequate since 0 makes no contribution to
// accumulated result. However, in stable softmax, all values 0s or not are subtracted by
// another value_max. As numbers become non-zero, effectively it allows invalid values to
// slip through and contribute to the accumulated result.
//
// The trick here is leveraging the fact that many math functions (add, sub, exp, ...)
// propagate NaNs when operands have NaNs involved. By initialiing invalid element value
// with NaN, an invalid value doing math manipulations is still NaN, which in turn can still
// be identified as an invalid value. We can then discard the invalid values which
// originally failed the bound check during accumulation. This allows to ignore values that
// failed bound check even after multiple math manipulations.
const
auto
in_global_val_buf_oob_nan
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_value_global
,
in_grid_desc_m_k
.
GetElementSpaceSize
(),
NumericLimits
<
InDataType
>::
QuietNaN
());
using
BlockwiseSumReduce
=
PartitionedBlockwiseReduction
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
reduce
::
Add
,
false
,
// ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Add
,
AccDataType
>>
;
using
ThreadwiseSumReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
reduce
::
Add
,
false
,
// ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Add
,
AccDataType
>>
;
reducedTiles
=
0
;
do
{
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
in_global_val_buf_oob_nan
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_buf
);
// do element-wise pre-reduction operation
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
in_thread_buf
(
Number
<
offset
>
{})
=
math
::
exp
(
in_thread_buf
(
Number
<
offset
>
{})
-
max_value_buf
(
iM
));
});
});
ThreadwiseSumReduce
::
Reduce
(
in_thread_buf
,
accu_value_buf
);
threadwise_src_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_bwd_step
);
reducedTiles
++
;
}
while
(
reducedTiles
<
num_k_block_tile_iteration
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
BlockwiseSumReduce
::
Reduce
(
reduce_work_buf
,
accu_value_buf
(
I
));
// block_sync_lds();
});
threadwise_src_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_fwd_step
);
///
/// softmax
///
reducedTiles
=
0
;
if
(
float_equal_zero
{}(
beta
))
{
do
{
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
in_global_val_buf_oob_nan
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
// out = alpha * exp(x - max(x)) / sum(exp(x - max(x)))
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
out_thread_buf
(
Number
<
offset
>
{})
=
alpha
*
math
::
exp
(
in_thread_buf
(
Number
<
offset
>
{})
-
max_value_buf
(
iM
))
/
accu_value_buf
(
iM
);
});
});
threadwise_dst_store
.
Run
(
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
out_thread_buf
,
out_grid_desc_m_k
,
out_global_val_buf
);
threadwise_src_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_fwd_step
);
threadwise_dst_store
.
MoveDstSliceWindow
(
out_grid_desc_m_k
,
in_thread_copy_fwd_step
);
reducedTiles
++
;
}
while
(
reducedTiles
<
num_k_block_tile_iteration
);
}
else
{
do
{
threadwise_src_load
.
Run
(
in_grid_desc_m_k
,
in_global_val_buf_oob_nan
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in_thread_buf
);
threadwise_dst_load
.
Run
(
out_grid_desc_m_k
,
out_global_val_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
out_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
// out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
out_thread_buf
(
Number
<
offset
>
{})
=
alpha
*
math
::
exp
(
in_thread_buf
(
Number
<
offset
>
{})
-
max_value_buf
(
iM
))
/
accu_value_buf
(
iM
)
+
beta
*
out_thread_buf
(
Number
<
offset
>
{});
});
});
threadwise_dst_store
.
Run
(
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
out_thread_buf
,
out_grid_desc_m_k
,
out_global_val_buf
);
threadwise_src_load
.
MoveSrcSliceWindow
(
in_grid_desc_m_k
,
in_thread_copy_fwd_step
);
threadwise_dst_store
.
MoveDstSliceWindow
(
out_grid_desc_m_k
,
in_thread_copy_fwd_step
);
threadwise_dst_load
.
MoveSrcSliceWindow
(
out_grid_desc_m_k
,
in_thread_copy_fwd_step
);
reducedTiles
++
;
}
while
(
reducedTiles
<
num_k_block_tile_iteration
);
}
}
};
}
// namespace ck
#endif // GRIDWISE_SOFTMAX_HPP
include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp
0 → 100644
View file @
b097be17
#pragma once
#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace
ck
{
template
<
typename
GridwiseUEltwise
,
typename
ADataType
,
typename
BDataType
,
typename
GridDesc_M0
,
typename
ElementwiseFunctor
>
__global__
void
kernel_unary_elementwise_1d
(
const
ADataType
*
__restrict__
p_a_global
,
BDataType
*
__restrict__
p_b_global
,
const
GridDesc_M0
a_grid_desc_m0
,
const
GridDesc_M0
b_grid_desc_m0
,
const
ElementwiseFunctor
functor
)
{
GridwiseUEltwise
::
Run
(
p_a_global
,
p_b_global
,
a_grid_desc_m0
,
b_grid_desc_m0
,
functor
);
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
GridDesc_M0
,
typename
ElementwiseFunctor
,
index_t
ScalarPerVector
>
struct
GridwiseUnaryElementwise_1D
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
thread_desc_m0
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
ScalarPerVector
>
{}));
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
static
__device__
auto
CalculateElementwiseIndex
()
{
const
index_t
global_thread_id
=
get_thread_global_1d_id
();
return
make_multi_index
(
global_thread_id
*
ScalarPerVector
);
}
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
GridDesc_M0
a_grid_desc_m0
,
const
GridDesc_M0
b_grid_desc_m0
)
{
return
a_grid_desc_m0
.
GetLength
(
I0
)
==
b_grid_desc_m0
.
GetLength
(
I0
);
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
index_t
tensor_size
)
{
const
index_t
grid_size
=
math
::
integer_divide_ceil
(
tensor_size
,
256
*
ScalarPerVector
);
return
grid_size
;
}
__device__
static
void
Run
(
const
ADataType
*
__restrict__
p_a_global
,
BDataType
*
__restrict__
p_b_global
,
const
GridDesc_M0
a_grid_desc_m0
,
const
GridDesc_M0
b_grid_desc_m0
,
const
ElementwiseFunctor
functor
)
{
const
auto
a_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_global
,
a_grid_desc_m0
.
GetElementSpaceSize
());
auto
b_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_global
,
b_grid_desc_m0
.
GetElementSpaceSize
());
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ADataType
,
ScalarPerVector
,
true
>
a_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
BDataType
,
ScalarPerVector
,
true
>
b_thread_buf
;
const
auto
thread_store_global_offset
=
CalculateElementwiseIndex
();
auto
a_global_load
=
ThreadwiseTensorSliceTransfer_v2
<
ADataType
,
ADataType
,
GridDesc_M0
,
decltype
(
thread_desc_m0
),
Sequence
<
ScalarPerVector
>
,
// SliceLengths
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
ScalarPerVector
,
1
,
// SrcScalarStrideInVector
false
>
{
a_grid_desc_m0
,
thread_store_global_offset
};
auto
b_global_write
=
ThreadwiseTensorSliceTransfer_v1r3
<
BDataType
,
BDataType
,
decltype
(
thread_desc_m0
),
GridDesc_M0
,
PassThrough
,
Sequence
<
ScalarPerVector
>
,
// SliceLengths
Sequence
<
0
>
,
// DimAccessOrder
0
,
// DstVectorDim
ScalarPerVector
,
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
false
>
{
b_grid_desc_m0
,
thread_store_global_offset
,
PassThrough
{}};
const
index_t
blockSize
=
get_block_size
();
const
index_t
blockPerGrid
=
get_grid_size
();
const
auto
m0
=
b_grid_desc_m0
.
GetLength
(
I0
);
const
index_t
loop_step
=
blockPerGrid
*
blockSize
*
ScalarPerVector
;
const
auto
loop_step_index
=
make_multi_index
(
loop_step
);
index_t
num_iter
=
m0
/
(
loop_step
);
do
{
// read and process ScalarPerVector elements
a_global_load
.
Run
(
a_grid_desc_m0
,
a_global_buf
,
thread_desc_m0
,
make_tuple
(
I0
),
a_thread_buf
);
static_for
<
0
,
ScalarPerVector
,
1
>
{}([
&
](
auto
m
)
{
constexpr
auto
offset
=
thread_desc_m0
.
CalculateOffset
(
make_tuple
(
m
));
functor
(
b_thread_buf
(
Number
<
offset
>
{}),
a_thread_buf
(
Number
<
offset
>
{}));
});
b_global_write
.
Run
(
thread_desc_m0
,
make_tuple
(
I0
),
// SrcSliceOriginIdx
b_thread_buf
,
b_grid_desc_m0
,
b_global_buf
);
a_global_load
.
MoveSrcSliceWindow
(
a_grid_desc_m0
,
loop_step_index
);
b_global_write
.
MoveDstSliceWindow
(
b_grid_desc_m0
,
loop_step_index
);
}
while
(
--
num_iter
);
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp
View file @
b097be17
...
...
@@ -39,7 +39,9 @@ template <typename AccDataType,
typename
SrcThreadDesc_M_K
,
typename
DstThreadDesc_M
,
typename
OpReduce
,
bool
PropagateNan
>
bool
PropagateNan
,
typename
Accumulation
=
detail
::
AccumulateWithNanCheck
<
PropagateNan
,
OpReduce
,
AccDataType
>
>
struct
ThreadwiseReduction
{
static
constexpr
auto
src_thread_desc_m_k
=
SrcThreadDesc_M_K
{};
...
...
@@ -51,8 +53,6 @@ struct ThreadwiseReduction
static_assert
(
src_length_m
==
dst_length_m
,
"lengths of source and dst buffer must match!"
);
using
Accumulation
=
detail
::
AccumulateWithNanCheck
<
PropagateNan
,
OpReduce
,
AccDataType
>
;
template
<
typename
SrcBufferType
,
typename
DstBufferType
>
__device__
static
void
Reduce
(
const
SrcBufferType
&
src_buf
,
DstBufferType
&
dst_buf
)
{
...
...
@@ -73,12 +73,15 @@ struct ThreadwiseReduction
// 2) DstDesc is known at compile-time
// 3) SrcBuffer is static buffer
// 4) DstBuffer is static buffer
template
<
typename
AccDataType
,
typename
IndexDataType
,
typename
SrcThreadDesc_M_K
,
typename
DstThreadDesc_M
,
typename
OpReduce
,
bool
PropagateNan
>
template
<
typename
AccDataType
,
typename
IndexDataType
,
typename
SrcThreadDesc_M_K
,
typename
DstThreadDesc_M
,
typename
OpReduce
,
bool
PropagateNan
,
typename
Accumulation
=
detail
::
AccumulateWithIndexAndNanCheck
<
PropagateNan
,
OpReduce
,
AccDataType
,
IndexDataType
>
>
struct
ThreadwiseReductionWithIndex
{
static
constexpr
auto
src_thread_desc_m_k
=
SrcThreadDesc_M_K
{};
...
...
@@ -90,9 +93,6 @@ struct ThreadwiseReductionWithIndex
static_assert
(
src_length_m
==
dst_length_m
,
"lengths of source and dst buffer must match!"
);
using
Accumulation
=
detail
::
AccumulateWithIndexAndNanCheck
<
PropagateNan
,
OpReduce
,
AccDataType
,
IndexDataType
>
;
template
<
typename
SrcValueBufferType
,
typename
SrcIndexBufferType
,
typename
DstValueBufferType
,
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp
0 → 100644
View file @
b097be17
#pragma once
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"
namespace
ck
{
// Thread-level multi-source, multi-destination tensor slice data movement
// Assume:
// 1. All sources and destinations are DynamicBuffer
// 2. Same VectorDim and ScalerPerVector for all sources and destinations
// 3. DstInMemOps are per destination tensor
// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
// 6. Does not need to know src_descs and dst_descs at compile-time
// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time,
//
// Does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer
// 2. Pass tensor descritpors by reference (or tuple of references)
// 3. Does not keep reference to tensor descriptor
// 4. Does not construct new tensor coordinate when call Run()
template
<
typename
SrcDatas
,
typename
DstDatas
,
typename
SrcDescs
,
typename
DstDescs
,
typename
ElementwiseOperation
,
typename
DstInMemOps
,
// Sequence<InMemoryDataOperationEnum ...>
typename
SliceLengths
,
typename
DimAccessOrder
,
index_t
VectorDim
,
index_t
ScalarPerVector
,
typename
SrcResetCoordinateAfterRunFlags
,
// Sequence<bool ...>
typename
DstResetCoordinateAfterRunFlags
>
// Sequence<bool ...>
struct
ThreadwiseTensorSliceTransfer_v7
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
index_t
nDim
=
SliceLengths
::
Size
();
static
constexpr
index_t
nSrc
=
SrcDescs
::
Size
();
static
constexpr
index_t
nDst
=
DstDescs
::
Size
();
using
Index
=
MultiIndex
<
nDim
>
;
// return a tuple of coordiantes for a tuple of tensor
template
<
typename
Descs
,
typename
Indices
,
enable_if_t
<
Descs
::
Size
()
==
Indices
::
Size
(),
bool
>
=
false
>
static
constexpr
auto
MakeCoordinates
(
const
Descs
&
descs
,
const
Indices
&
indices
)
{
return
generate_tuple
([
&
](
auto
i
)
{
return
make_tensor_coordinate
(
descs
[
i
],
indices
[
i
]);
},
Number
<
Descs
::
Size
()
>
{});
}
using
SrcCoords
=
decltype
(
MakeCoordinates
(
SrcDescs
{},
StaticallyIndexedArray
<
Index
,
nSrc
>
{}));
using
DstCoords
=
decltype
(
MakeCoordinates
(
DstDescs
{},
StaticallyIndexedArray
<
Index
,
nDst
>
{}));
// scalar per access on each dim
// FIXME: don't use lambda_scalar_per_access
static
constexpr
auto
scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
VectorDim
,
ScalarPerVector
>
{},
Number
<
nDim
>
{});
using
SpaceFillingCurve
=
SpaceFillingCurve
<
SliceLengths
,
DimAccessOrder
,
remove_cv_t
<
decltype
(
scalar_per_access
)
>>
;
__device__
constexpr
ThreadwiseTensorSliceTransfer_v7
(
const
SrcDescs
&
src_descs
,
const
StaticallyIndexedArray
<
Index
,
nSrc
>&
src_slice_origins
,
const
DstDescs
&
dst_descs
,
const
StaticallyIndexedArray
<
Index
,
nDst
>&
dst_slice_origins
,
const
ElementwiseOperation
&
element_op
)
:
src_coords_
(
MakeCoordinates
(
src_descs
,
src_slice_origins
)),
dst_coords_
(
MakeCoordinates
(
dst_descs
,
dst_slice_origins
)),
element_op_
(
element_op
)
{
static_assert
(
SliceLengths
::
At
(
Number
<
VectorDim
>
{})
%
ScalarPerVector
==
0
,
"wrong! cannot evenly divide"
);
}
template
<
typename
Indices
,
enable_if_t
<
SrcDescs
::
Size
()
==
Indices
::
Size
(),
bool
>
=
false
>
__device__
void
SetSrcSliceOrigins
(
const
SrcDescs
&
src_descs
,
const
Indices
&
src_slice_origin_idxs
)
{
static_for
<
0
,
nSrc
,
1
>
{}([
&
](
auto
i
)
{
src_coords_
(
i
)
=
make_tensor_coordinate
(
src_descs
[
i
],
src_slice_origin_idxs
[
i
]);
});
}
template
<
typename
Indices
,
enable_if_t
<
DstDescs
::
Size
()
==
Indices
::
Size
(),
bool
>
=
false
>
__device__
void
SetDstSliceOrigins
(
const
DstDescs
&
dst_descs
,
const
Indices
&
dst_slice_origin_idxs
)
{
static_for
<
0
,
nDst
,
1
>
{}([
&
](
auto
i
)
{
dst_coords_
(
i
)
=
make_tensor_coordinate
(
dst_descs
[
i
],
dst_slice_origin_idxs
[
i
]);
});
}
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template
<
typename
SrcBuffers
,
typename
DstBuffers
,
enable_if_t
<
SrcDescs
::
Size
()
==
SrcBuffers
::
Size
()
&&
DstDescs
::
Size
()
==
DstBuffers
::
Size
(),
bool
>
=
false
>
__device__
void
Run
(
const
SrcDescs
&
src_descs
,
const
SrcBuffers
&
src_bufs
,
const
DstDescs
&
dst_descs
,
DstBuffers
dst_bufs
)
{
auto
generate_vectors
=
[
&
](
auto
data_types
)
{
constexpr
index_t
num
=
data_types
.
Size
();
return
generate_tuple
(
[
&
](
auto
i
)
{
using
DataType
=
remove_cvref_t
<
decltype
(
data_types
[
i
])
>
;
return
vector_type_maker_t
<
DataType
,
ScalarPerVector
>
{};
},
Number
<
num
>
{});
};
constexpr
auto
num_access
=
SpaceFillingCurve
::
GetNumOfAccess
();
// loop over space-filling curve
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
iAccess
)
{
auto
src_vectors
=
generate_vectors
(
SrcDatas
{});
auto
dst_vectors
=
generate_vectors
(
DstDatas
{});
// copy data from src_bufs into src_vectors
static_for
<
0
,
nSrc
,
1
>
{}([
&
](
auto
i
)
{
using
src_vector_t
=
typename
remove_cvref_t
<
decltype
(
src_vectors
[
i
])
>::
type
;
const
bool
is_src_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
src_descs
[
i
],
src_coords_
[
i
]);
src_vectors
(
i
).
template
AsType
<
src_vector_t
>()(
I0
)
=
src_bufs
[
i
].
template
Get
<
src_vector_t
>(
src_coords_
[
i
].
GetOffset
(),
is_src_valid
);
});
// apply pointwise function
static_for
<
0
,
ScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
// get reference to src data
const
auto
src_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iSrc
)
->
const
auto
&
{
using
SrcData
=
remove_cvref_t
<
tuple_element_t
<
iSrc
.
value
,
SrcDatas
>>
;
return
src_vectors
[
iSrc
].
template
AsType
<
SrcData
>()[
i
];
},
Number
<
nSrc
>
{});
// get reference to dst data
auto
dst_data_refs
=
generate_tie
(
// return type should be lvalue
[
&
](
auto
iDst
)
->
auto
&
{
using
DstData
=
remove_cvref_t
<
tuple_element_t
<
iDst
.
value
,
DstDatas
>>
;
return
dst_vectors
(
iDst
).
template
AsType
<
DstData
>()(
i
);
},
Number
<
nDst
>
{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2
(
element_op_
,
dst_data_refs
,
src_data_refs
);
});
// copy data from buf_vectors into dst_bufs
static_for
<
0
,
nDst
,
1
>
{}([
&
](
auto
i
)
{
using
dst_vector_t
=
typename
remove_cvref_t
<
decltype
(
dst_vectors
[
i
])
>::
type
;
const
bool
is_dst_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
dst_descs
[
i
],
dst_coords_
[
i
]);
constexpr
InMemoryDataOperationEnum
DstInMemOp
=
static_cast
<
InMemoryDataOperationEnum
>
(
DstInMemOps
::
At
(
i
.
value
));
dst_bufs
(
i
).
template
Update
<
DstInMemOp
,
dst_vector_t
>(
dst_coords_
[
i
].
GetOffset
(),
is_dst_valid
,
dst_vectors
[
i
].
template
AsType
<
dst_vector_t
>()[
I0
]);
});
// move coordinate
if
constexpr
(
iAccess
.
value
!=
num_access
-
1
)
{
constexpr
auto
forward_step
=
SpaceFillingCurve
::
GetForwardStep
(
iAccess
);
static_for
<
0
,
nSrc
,
1
>
{}([
&
](
auto
i
)
{
move_tensor_coordinate
(
src_descs
[
i
],
src_coords_
(
i
),
make_tensor_coordinate_step
(
src_descs
[
i
],
forward_step
));
});
static_for
<
0
,
nDst
,
1
>
{}([
&
](
auto
i
)
{
move_tensor_coordinate
(
dst_descs
[
i
],
dst_coords_
(
i
),
make_tensor_coordinate_step
(
dst_descs
[
i
],
forward_step
));
});
}
});
// move coordinate back to slice origin (or not)
static_for
<
0
,
nSrc
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
SrcResetCoordinateAfterRunFlags
::
At
(
i
))
{
const
auto
src_reset_step
=
make_tensor_coordinate_step
(
src_descs
[
i
],
GetCoordinateResetStep
());
move_tensor_coordinate
(
src_descs
[
i
],
src_coords_
(
i
),
src_reset_step
);
}
});
static_for
<
0
,
nDst
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
DstResetCoordinateAfterRunFlags
::
At
(
i
))
{
const
auto
dst_reset_step
=
make_tensor_coordinate_step
(
dst_descs
[
i
],
GetCoordinateResetStep
());
move_tensor_coordinate
(
dst_descs
[
i
],
dst_coords_
(
i
),
dst_reset_step
);
}
});
}
__device__
static
constexpr
auto
GetCoordinateResetStep
()
{
constexpr
auto
num_access
=
SpaceFillingCurve
::
GetNumOfAccess
();
if
constexpr
(
num_access
==
0
)
{
return
typename
SpaceFillingCurve
::
Index
{};
}
else
{
constexpr
auto
reset_step
=
SpaceFillingCurve
::
GetStepBetween
(
Number
<
num_access
-
1
>
{},
Number
<
0
>
{});
return
reset_step
;
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
template
<
index_t
ISrc
>
__device__
void
MoveSrcSliceWindow
(
const
SrcDescs
&
src_descs
,
Number
<
ISrc
>
iSrc
,
const
Index
&
src_slice_origin_step_idx
)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const
auto
adjusted_step_idx
=
SrcResetCoordinateAfterRunFlags
::
At
(
iSrc
)
?
src_slice_origin_step_idx
:
src_slice_origin_step_idx
+
GetCoordinateResetStep
();
// is it OK to construct a new step every time?
const
auto
adjusted_step
=
make_tensor_coordinate_step
(
src_descs
[
iSrc
],
adjusted_step_idx
);
move_tensor_coordinate
(
src_descs
[
iSrc
],
src_coords_
(
iSrc
),
adjusted_step
);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
template
<
index_t
IDst
>
__device__
void
MoveDstSliceWindow
(
const
DstDescs
&
dst_descs
,
Number
<
IDst
>
iDst
,
const
Index
&
dst_slice_origin_step_idx
)
{
// if dst coord was not reset by Run(), then need to adjust the step here
const
auto
adjusted_step_idx
=
DstResetCoordinateAfterRunFlags
::
At
(
iDst
)
?
dst_slice_origin_step_idx
:
dst_slice_origin_step_idx
+
GetCoordinateResetStep
();
// is it OK to construct a new step every time?
const
auto
adjusted_step
=
make_tensor_coordinate_step
(
dst_descs
[
iDst
],
adjusted_step_idx
);
move_tensor_coordinate
(
dst_descs
[
iDst
],
dst_coords_
(
iDst
),
adjusted_step
);
}
private:
SrcCoords
src_coords_
;
DstCoords
dst_coords_
;
const
ElementwiseOperation
element_op_
;
};
}
// namespace ck
include/ck/utility/amd_buffer_addressing.hpp
View file @
b097be17
...
...
@@ -6,6 +6,8 @@ namespace ck {
template
<
typename
T
>
union
BufferResource
{
__device__
constexpr
BufferResource
()
:
content
{}
{}
// 128 bit SGPRs to supply buffer resource in buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
int32x4_t
content
;
...
...
include/ck/utility/data_type.hpp
View file @
b097be17
#pragma once
#include "statically_indexed_array.hpp"
namespace
ck
{
...
...
@@ -1000,6 +1001,11 @@ struct NumericLimits
__host__
__device__
static
constexpr
T
Max
()
{
return
std
::
numeric_limits
<
T
>::
max
();
}
__host__
__device__
static
constexpr
T
Lowest
()
{
return
std
::
numeric_limits
<
T
>::
lowest
();
}
__host__
__device__
static
constexpr
T
QuietNaN
()
{
return
std
::
numeric_limits
<
T
>::
quiet_NaN
();
}
};
template
<
>
...
...
@@ -1008,12 +1014,15 @@ struct NumericLimits<half_t>
static
constexpr
unsigned
short
binary_min
=
0x0400
;
static
constexpr
unsigned
short
binary_max
=
0x7BFF
;
static
constexpr
unsigned
short
binary_lowest
=
0xFBFF
;
static
constexpr
unsigned
short
binary_qnan
=
0x7FFF
;
__host__
__device__
static
constexpr
half_t
Min
()
{
return
bit_cast
<
half_t
>
(
binary_min
);
}
__host__
__device__
static
constexpr
half_t
Max
()
{
return
bit_cast
<
half_t
>
(
binary_max
);
}
__host__
__device__
static
constexpr
half_t
Lowest
()
{
return
bit_cast
<
half_t
>
(
binary_lowest
);
}
__host__
__device__
static
constexpr
half_t
QuietNaN
()
{
return
bit_cast
<
half_t
>
(
binary_qnan
);
}
};
}
// namespace ck
include/ck/utility/enable_if.hpp
View file @
b097be17
#ifndef CK_ENABLE_IF_HPP
#define CK_ENABLE_IF_HPP
#pragma once
namespace
ck
{
...
...
@@ -10,4 +9,3 @@ template <bool B, typename T = void>
using
enable_if_t
=
typename
std
::
enable_if
<
B
,
T
>::
type
;
}
// namespace ck
#endif
include/ck/utility/math.hpp
View file @
b097be17
...
...
@@ -142,6 +142,22 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
return
min
(
x
,
min
(
ys
...));
}
// disallow implicit type casting
template
<
typename
T
>
__device__
T
exp
(
T
x
);
template
<
>
__device__
float
exp
<
float
>
(
float
x
)
{
return
__expf
(
x
);
}
template
<
>
__device__
double
exp
<
double
>
(
double
x
)
{
return
exp
(
x
);
}
// greatest common divisor, aka highest common factor
__host__
__device__
constexpr
index_t
gcd
(
index_t
x
,
index_t
y
)
{
...
...
include/ck/utility/reduction_functions_accumulate.hpp
View file @
b097be17
...
...
@@ -35,9 +35,27 @@
namespace
ck
{
namespace
detail
{
// Check for NaN; guarantee NaNs are NOT propagated to result (i.e., ignore NaNs)
template
<
typename
ReduceOperation
,
typename
AccDataType
>
struct
AccumulateWithNanIgnore
{
__device__
static
inline
void
Calculate
(
AccDataType
&
accuVal
,
AccDataType
currVal
)
{
if
(
!
isnan
(
currVal
))
{
ReduceOperation
{}(
accuVal
,
currVal
);
}
};
};
template
<
bool
PropagateNan
,
typename
ReduceOperation
,
typename
AccDataType
>
struct
AccumulateWithNanCheck
;
// Does not check for NaN; does not guarantee NaNs be propagated to result
// e.g., given that max(a, b) = a > b ? a : b
// then max(NaN, 1) returns 1
// max(1, NaN) returns NaN
// since any comparison involving NaNs returns false
template
<
typename
ReduceOperation
,
typename
AccDataType
>
struct
AccumulateWithNanCheck
<
false
,
ReduceOperation
,
AccDataType
>
{
...
...
@@ -48,6 +66,7 @@ struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
};
};
// Check for NaN; guarantees NaNs be propagated to result
template
<
typename
ReduceOperation
,
typename
AccDataType
>
struct
AccumulateWithNanCheck
<
true
,
ReduceOperation
,
AccDataType
>
{
...
...
include/ck/utility/reduction_operator.hpp
View file @
b097be17
...
...
@@ -28,6 +28,7 @@
#include "config.hpp"
#include "data_type.hpp"
#include "type.hpp"
namespace
ck
{
...
...
@@ -54,64 +55,92 @@ namespace reduce {
// accumulated index also need be
// changed.
template
<
class
T
>
struct
Add
{
using
dataType
=
T
;
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
static_cast
<
T
>
(
0.0
f
);
};
template
<
typename
T
>
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
type_convert
<
T
>
(
0.0
f
);
};
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
IsCompatibleInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
return
operation
==
InMemoryDataOperationEnum
::
AtomicAdd
||
operation
==
InMemoryDataOperationEnum
::
Set
;
};
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
a
=
a
+
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"The data type is not supported by the Add accumulator!"
);
a
=
a
+
b
;
}
};
template
<
class
T
>
struct
Mul
{
using
dataType
=
T
;
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
static_cast
<
T
>
(
1.0
f
);
};
template
<
typename
T
>
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
type_convert
<
T
>
(
1.0
f
);
};
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
IsCompatibleInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
return
operation
==
InMemoryDataOperationEnum
::
Set
;
};
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
a
=
a
*
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"The data type is not supported by the Mul accumulator!"
);
a
=
a
*
b
;
}
};
template
<
class
T
>
struct
Max
{
using
dataType
=
T
;
template
<
typename
T
>
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
NumericLimits
<
T
>::
Lowest
();
};
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
IsCompatibleInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
// ToChange: atomic_max to be added
return
operation
==
InMemoryDataOperationEnum
::
Set
;
};
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Max accumulator!"
);
if
(
a
<
b
)
a
=
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Max accumulator!"
);
if
(
a
<
b
)
{
a
=
b
;
...
...
@@ -120,28 +149,41 @@ struct Max
}
};
template
<
class
T
>
struct
Min
{
using
dataType
=
T
;
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
NumericLimits
<
T
>::
Max
();
};
template
<
typename
T
>
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
NumericLimits
<
T
>::
Max
();
};
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
IsCompatibleInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
// ToChange: atomic_min to be added
return
operation
==
InMemoryDataOperationEnum
::
Set
;
};
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Min accumulator!"
);
if
(
a
>
b
)
a
=
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Min accumulator!"
);
if
(
a
>
b
)
{
a
=
b
;
...
...
@@ -150,28 +192,41 @@ struct Min
}
};
template
<
class
T
>
struct
AMax
{
using
dataType
=
T
;
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
static_cast
<
T
>
(
0.0
f
);
};
template
<
typename
T
>
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
type_convert
<
T
>
(
0.0
f
);
};
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
IsCompatibleInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
// ToChange: atomic_max to be added
return
operation
==
InMemoryDataOperationEnum
::
Set
;
};
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the AMax accumulator!"
);
if
(
a
<
b
)
a
=
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the AMax accumulator!"
);
if
(
a
<
b
)
{
a
=
b
;
...
...
@@ -181,7 +236,7 @@ struct AMax
};
template
<
typename
T
>
T
GetIdentityValue
ue
ForInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
constexpr
T
GetIdentityValueForInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
T
result
=
ck
::
type_convert
<
T
>
(
0.0
f
);
...
...
@@ -191,6 +246,44 @@ T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation
return
(
result
);
};
template
<
InMemoryDataOperationEnum
Operation
,
typename
DataType
>
struct
InMemoryDataOperatonSupportedOnDataType
{
static
constexpr
bool
value
=
false
;
};
template
<
typename
DataType
>
struct
InMemoryDataOperatonSupportedOnDataType
<
InMemoryDataOperationEnum
::
AtomicAdd
,
DataType
>
{
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
;
};
template
<
typename
DataType
>
struct
InMemoryDataOperatonSupportedOnDataType
<
InMemoryDataOperationEnum
::
AtomicMax
,
DataType
>
{
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
;
};
template
<
typename
DataType
>
struct
InMemoryDataOperatonSupportedOnDataType
<
InMemoryDataOperationEnum
::
Set
,
DataType
>
{
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
||
is_same
<
DataType
,
half_t
>::
value
||
is_same
<
DataType
,
bhalf_t
>::
value
||
is_same
<
DataType
,
int8_t
>::
value
||
is_same
<
DataType
,
int32_t
>::
value
;
};
template
<
typename
DataType
>
struct
InMemoryDataOperatonSupportedOnDataType
<
InMemoryDataOperationEnum
::
Add
,
DataType
>
{
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
||
is_same
<
DataType
,
half_t
>::
value
||
is_same
<
DataType
,
int8_t
>::
value
||
is_same
<
DataType
,
int32_t
>::
value
;
};
};
// end of namespace reduce
}
// end of namespace ck
...
...
include/ck/utility/sequence.hpp
View file @
b097be17
#ifndef CK_SEQUENCE_HPP
#define CK_SEQUENCE_HPP
#pragma once
#include "integral_constant.hpp"
#include "type.hpp"
...
...
@@ -241,7 +240,13 @@ struct arithmetic_sequence_gen
}
};
using
type
=
typename
sequence_gen
<
(
IEnd
-
IBegin
)
/
Increment
,
F
>::
type
;
using
type0
=
typename
sequence_gen
<
(
IEnd
-
IBegin
)
/
Increment
,
F
>::
type
;
using
type1
=
Sequence
<>
;
static
constexpr
bool
kHasContent
=
(
Increment
>
0
&&
IBegin
<
IEnd
)
||
(
Increment
<
0
&&
IBegin
>
IEnd
);
using
type
=
typename
conditional
<
kHasContent
,
type0
,
type1
>::
type
;
};
// uniform sequence
...
...
@@ -882,5 +887,10 @@ __host__ __device__ constexpr bool sequence_all_of(Seq, F f)
return
flag
;
}
template
<
typename
Sx
,
typename
Sy
>
using
sequence_merge_t
=
typename
sequence_merge
<
Sx
,
Sy
>::
type
;
template
<
index_t
NSize
,
index_t
I
>
using
uniform_sequence_gen_t
=
typename
uniform_sequence_gen
<
NSize
,
I
>::
type
;
}
// namespace ck
#endif
include/ck/utility/tuple.hpp
View file @
b097be17
#ifndef CK_TUPLE_HPP
#define CK_TUPLE_HPP
#pragma once
#include "integral_constant.hpp"
#include "sequence.hpp"
...
...
@@ -17,14 +16,18 @@ struct TupleElementKey
};
template
<
typename
Key
,
typename
Data
>
struct
TupleElement
struct
TupleElement
KeyData
{
__host__
__device__
constexpr
TupleElement
()
=
default
;
#if 0 // workaround compiler complaint about implicitly-deleted default constructor
__host__ __device__ constexpr TupleElementKeyData() = default;
#else
__host__
__device__
constexpr
TupleElementKeyData
()
:
mData
{}
{}
#endif
template
<
typename
T
,
typename
enable_if
<!
is_same
<
remove_cvref_t
<
T
>,
TupleElement
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleElement
(
T
&&
v
)
:
mData
(
std
::
forward
<
T
>
(
v
))
template
<
typename
T
,
typename
enable_if
<!
is_same
<
remove_cvref_t
<
T
>,
TupleElementKeyData
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleElement
KeyData
(
T
&&
v
)
:
mData
(
std
::
forward
<
T
>
(
v
))
{
}
...
...
@@ -32,20 +35,21 @@ struct TupleElement
};
template
<
typename
Key
,
typename
Data
>
__host__
__device__
constexpr
const
Data
&
get_tuple_element
(
const
TupleElement
<
Key
,
Data
>&
x
)
__host__
__device__
constexpr
const
Data
&
get_tuple_element_data
(
const
TupleElementKeyData
<
Key
,
Data
>&
x
)
{
return
static_cast
<
const
Data
&>
(
x
.
mData
);
}
template
<
typename
Key
,
typename
Data
>
__host__
__device__
constexpr
Data
&
get_tuple_element
(
TupleElement
<
Key
,
Data
>&
x
)
__host__
__device__
constexpr
Data
&
get_tuple_element
_data
(
TupleElement
KeyData
<
Key
,
Data
>&
x
)
{
return
x
.
mData
;
}
// TODO: not sure the use of reference is correct
template
<
typename
Key
,
typename
Data
>
__host__
__device__
constexpr
Data
&&
get_tuple_element
(
TupleElement
<
Key
,
Data
>&&
x
)
__host__
__device__
constexpr
Data
&&
get_tuple_element
_data
(
TupleElement
KeyData
<
Key
,
Data
>&&
x
)
{
return
static_cast
<
Data
&&>
(
x
.
mData
);
}
...
...
@@ -54,7 +58,7 @@ template <typename Indices, typename... Xs>
struct
TupleImpl
;
template
<
index_t
...
Is
,
typename
...
Xs
>
struct
TupleImpl
<
Sequence
<
Is
...
>
,
Xs
...
>
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
...
struct
TupleImpl
<
Sequence
<
Is
...
>
,
Xs
...
>
:
TupleElement
KeyData
<
TupleElementKey
<
Is
>
,
Xs
>
...
{
__host__
__device__
constexpr
TupleImpl
()
=
default
;
...
...
@@ -63,13 +67,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
!
is_same
<
remove_cvref_t
<
Y
>,
TupleImpl
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleImpl
(
Y
&&
y
)
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Y
>
(
y
))...
:
TupleElement
KeyData
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Y
>
(
y
))...
{
}
template
<
typename
...
Ys
,
typename
enable_if
<
sizeof
...(
Ys
)
>
=
2
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleImpl
(
Ys
&&
...
ys
)
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Ys
>
(
ys
))...
:
TupleElement
KeyData
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Ys
>
(
ys
))...
{
static_assert
(
sizeof
...(
Is
)
==
sizeof
...(
Xs
)
&&
sizeof
...(
Is
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
...
...
@@ -78,15 +82,15 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
__host__
__device__
static
constexpr
index_t
Size
()
{
return
sizeof
...(
Xs
);
}
template
<
index_t
I
>
__host__
__device__
constexpr
const
auto
&
GetElementByKey
(
TupleElementKey
<
I
>
)
const
__host__
__device__
constexpr
const
auto
&
GetElement
Data
ByKey
(
TupleElementKey
<
I
>
)
const
{
return
get_tuple_element
<
TupleElementKey
<
I
>>
(
*
this
);
return
get_tuple_element
_data
<
TupleElementKey
<
I
>>
(
*
this
);
}
template
<
index_t
I
>
__host__
__device__
constexpr
auto
&
GetElementByKey
(
TupleElementKey
<
I
>
)
__host__
__device__
constexpr
auto
&
GetElement
Data
ByKey
(
TupleElementKey
<
I
>
)
{
return
get_tuple_element
<
TupleElementKey
<
I
>>
(
*
this
);
return
get_tuple_element
_data
<
TupleElementKey
<
I
>>
(
*
this
);
}
};
...
...
@@ -121,7 +125,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__
__device__
constexpr
const
auto
&
At
(
Number
<
I
>
)
const
{
static_assert
(
I
<
base
::
Size
(),
"wrong! out of range"
);
return
base
::
GetElementByKey
(
detail
::
TupleElementKey
<
I
>
{});
return
base
::
GetElement
Data
ByKey
(
detail
::
TupleElementKey
<
I
>
{});
}
// write access
...
...
@@ -129,7 +133,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__
__device__
constexpr
auto
&
At
(
Number
<
I
>
)
{
static_assert
(
I
<
base
::
Size
(),
"wrong! out of range"
);
return
base
::
GetElementByKey
(
detail
::
TupleElementKey
<
I
>
{});
return
base
::
GetElement
Data
ByKey
(
detail
::
TupleElementKey
<
I
>
{});
}
// read access
...
...
@@ -159,6 +163,31 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
true
;
}
};
template
<
>
struct
Tuple
<>
{
__host__
__device__
constexpr
Tuple
()
=
default
;
__host__
__device__
static
constexpr
index_t
Size
()
{
return
0
;
}
template
<
typename
T
>
__host__
__device__
constexpr
auto
operator
=
(
const
T
&
)
{
return
*
this
;
}
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
true
;
}
};
template
<
index_t
I
,
typename
TTuple
>
struct
tuple_element
{
using
type
=
decltype
(
TTuple
{}.
At
(
Number
<
I
>
{}));
};
template
<
index_t
I
,
typename
TTuple
>
using
tuple_element_t
=
typename
tuple_element
<
I
,
TTuple
>::
type
;
template
<
typename
...
Xs
>
__host__
__device__
constexpr
auto
make_tuple
(
Xs
&&
...
xs
)
{
...
...
@@ -173,4 +202,3 @@ constexpr Tuple<Args&...> tie(Args&... args) noexcept
}
}
// namespace ck
#endif
include/ck/utility/tuple_helper.hpp
View file @
b097be17
#ifndef CK_TUPLE_HELPER_HPP
#define CK_TUPLE_HELPER_HPP
#pragma once
#include "functional4.hpp"
#include "tuple.hpp"
...
...
@@ -20,6 +19,17 @@ __host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
typename
arithmetic_sequence_gen
<
0
,
N
,
1
>::
type
{});
}
// tx and ty are tuple of references, return type of will tuple of referennce (not rvalue)
template
<
typename
...
X
,
typename
...
Y
>
__host__
__device__
constexpr
auto
concat_tuple_of_reference
(
const
Tuple
<
X
&
...
>&
tx
,
const
Tuple
<
Y
&
...
>&
ty
)
{
return
unpack2
(
[
&
](
auto
&&
...
zs
)
{
return
Tuple
<
decltype
(
zs
)...
>
{
std
::
forward
<
decltype
(
zs
)
>
(
zs
)...};
},
tx
,
ty
);
}
namespace
detail
{
template
<
typename
F
,
typename
X
,
index_t
...
Is
>
...
...
@@ -66,4 +76,3 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
}
}
// namespace ck
#endif
library/include/ck/library/host_tensor/host_reduction.hpp
View file @
b097be17
...
...
@@ -174,15 +174,18 @@ struct ReductionHost
const
InDataType
*
in_data
,
float
beta
,
OutDataType
*
out_data
,
IndexDataType
*
out_indices
)
IndexDataType
*
out_indices
,
InElementwiseOperation
in_elementwise_op
,
AccElementwiseOperation
acc_elementwise_op
)
{
if
constexpr
(
OutputIndex
)
{
RunImpl_with_index
(
alpha
,
in_data
,
beta
,
out_data
,
out_indices
);
RunImpl_with_index
(
alpha
,
in_data
,
beta
,
out_data
,
out_indices
,
in_elementwise_op
,
acc_elementwise_op
);
}
else
{
RunImpl_no_index
(
alpha
,
in_data
,
beta
,
out_data
);
RunImpl_no_index
(
alpha
,
in_data
,
beta
,
out_data
,
in_elementwise_op
,
acc_elementwise_op
);
};
};
...
...
@@ -190,7 +193,9 @@ struct ReductionHost
const
InDataType
*
in_data
,
float
beta
,
OutDataType
*
out_data
,
IndexDataType
*
out_indices
)
IndexDataType
*
out_indices
,
InElementwiseOperation
in_elementwise_op
,
AccElementwiseOperation
acc_elementwise_op
)
{
using
ck
::
float_equal_one
;
using
ck
::
float_equal_zero
;
...
...
@@ -200,12 +205,10 @@ struct ReductionHost
ReduceOperation
,
AccDataType
,
IndexDataType
>
;
InElementwiseOperation
in_elementwise_op
(
divider
);
AccElementwiseOperation
acc_elementwise_op
(
divider
);
if
constexpr
(
NumInvariantDim
==
0
)
{
AccDataType
accuVal
=
ReduceOperation
::
GetIdentityValue
();
AccDataType
accuVal
=
ReduceOperation
::
template
GetIdentityValue
<
AccDataType
>
();
IndexDataType
accuIndex
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
reduce_dim_indexes
.
size
();
i
++
)
...
...
@@ -236,7 +239,7 @@ struct ReductionHost
else
{
auto
thread_reduce_func
=
[
&
](
auto
invariant_index
)
{
AccDataType
accuVal
=
ReduceOperation
::
GetIdentityValue
();
AccDataType
accuVal
=
ReduceOperation
::
template
GetIdentityValue
<
AccDataType
>
();
IndexDataType
accuIndex
=
0
;
auto
offset_invariant
=
...
...
@@ -297,7 +300,12 @@ struct ReductionHost
};
};
void
RunImpl_no_index
(
float
alpha
,
const
InDataType
*
in_data
,
float
beta
,
OutDataType
*
out_data
)
void
RunImpl_no_index
(
float
alpha
,
const
InDataType
*
in_data
,
float
beta
,
OutDataType
*
out_data
,
InElementwiseOperation
in_elementwise_op
,
AccElementwiseOperation
acc_elementwise_op
)
{
using
ck
::
float_equal_one
;
using
ck
::
float_equal_zero
;
...
...
@@ -306,12 +314,9 @@ struct ReductionHost
using
Accumulation
=
ck
::
detail
::
AccumulateWithNanCheck
<
PropagateNan
,
ReduceOperation
,
AccDataType
>
;
InElementwiseOperation
in_elementwise_op
(
divider
);
AccElementwiseOperation
acc_elementwise_op
(
divider
);
if
constexpr
(
NumInvariantDim
==
0
)
{
AccDataType
accuVal
=
ReduceOperation
::
GetIdentityValue
();
AccDataType
accuVal
=
ReduceOperation
::
template
GetIdentityValue
<
AccDataType
>
();
for
(
const
auto
&
reduce_index
:
reduce_dim_indexes
)
{
...
...
@@ -338,7 +343,7 @@ struct ReductionHost
else
{
auto
thread_reduce_func
=
[
&
](
auto
invariant_index
)
{
AccDataType
accuVal
=
ReduceOperation
::
GetIdentityValue
();
AccDataType
accuVal
=
ReduceOperation
::
template
GetIdentityValue
<
AccDataType
>
();
auto
offset_invariant
=
get_offset_from_index
<
NumInvariantDim
>
(
invariantStrides
,
invariant_index
);
...
...
library/include/ck/library/host_tensor/host_tensor.hpp
View file @
b097be17
...
...
@@ -107,6 +107,11 @@ struct HostTensorDescriptor
return
std
::
inner_product
(
iss
.
begin
(),
iss
.
end
(),
mStrides
.
begin
(),
std
::
size_t
{
0
});
}
std
::
size_t
GetOffsetFromMultiIndex
(
std
::
vector
<
std
::
size_t
>
iss
)
const
{
return
std
::
inner_product
(
iss
.
begin
(),
iss
.
end
(),
mStrides
.
begin
(),
std
::
size_t
{
0
});
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
HostTensorDescriptor
&
desc
);
private:
...
...
@@ -212,6 +217,54 @@ struct Tensor
Tensor
(
const
HostTensorDescriptor
&
desc
)
:
mDesc
(
desc
),
mData
(
mDesc
.
GetElementSpace
())
{}
Tensor
(
const
Tensor
&
other
)
:
mDesc
(
other
.
mDesc
),
mData
(
other
.
mData
)
{}
template
<
typename
F
>
void
ForEach_impl
(
F
&&
f
,
std
::
vector
<
size_t
>&
idx
,
size_t
rank
)
{
if
(
rank
==
mDesc
.
GetNumOfDimension
())
{
f
(
*
this
,
idx
);
return
;
}
// else
for
(
size_t
i
=
0
;
i
<
mDesc
.
GetLengths
()[
rank
];
i
++
)
{
idx
[
rank
]
=
i
;
ForEach_impl
(
std
::
forward
<
F
>
(
f
),
idx
,
rank
+
1
);
}
}
template
<
typename
F
>
void
ForEach
(
F
&&
f
)
{
std
::
vector
<
size_t
>
idx
(
mDesc
.
GetNumOfDimension
(),
0
);
ForEach_impl
(
std
::
forward
<
F
>
(
f
),
idx
,
size_t
(
0
));
}
template
<
typename
F
>
void
ForEach_impl
(
const
F
&&
f
,
std
::
vector
<
size_t
>&
idx
,
size_t
rank
)
const
{
if
(
rank
==
mDesc
.
GetNumOfDimension
())
{
f
(
*
this
,
idx
);
return
;
}
// else
for
(
size_t
i
=
0
;
i
<
mDesc
.
GetLengths
()[
rank
];
i
++
)
{
idx
[
rank
]
=
i
;
ForEach_impl
(
std
::
forward
<
const
F
>
(
f
),
idx
,
rank
+
1
);
}
}
template
<
typename
F
>
void
ForEach
(
const
F
&&
f
)
const
{
std
::
vector
<
size_t
>
idx
(
mDesc
.
GetNumOfDimension
(),
0
);
ForEach_impl
(
std
::
forward
<
const
F
>
(
f
),
idx
,
size_t
(
0
));
}
template
<
typename
G
>
void
GenerateTensorValue
(
G
g
,
std
::
size_t
num_thread
=
1
)
{
...
...
@@ -272,6 +325,16 @@ struct Tensor
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
is
...)];
}
T
&
operator
()(
std
::
vector
<
std
::
size_t
>
idx
)
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)];
}
const
T
&
operator
()(
std
::
vector
<
std
::
size_t
>
idx
)
const
{
return
mData
[
mDesc
.
GetOffsetFromMultiIndex
(
idx
)];
}
typename
std
::
vector
<
T
>::
iterator
begin
()
{
return
mData
.
begin
();
}
typename
std
::
vector
<
T
>::
iterator
end
()
{
return
mData
.
end
();
}
...
...
@@ -285,7 +348,8 @@ struct Tensor
};
template
<
typename
X
>
HostTensorDescriptor
::
HostTensorDescriptor
(
const
std
::
vector
<
X
>&
lens
)
:
mLens
(
lens
)
HostTensorDescriptor
::
HostTensorDescriptor
(
const
std
::
vector
<
X
>&
lens
)
:
mLens
(
lens
.
begin
(),
lens
.
end
())
{
this
->
CalculateStrides
();
}
...
...
@@ -293,7 +357,7 @@ HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens) : mLens(l
template
<
typename
X
,
typename
Y
>
HostTensorDescriptor
::
HostTensorDescriptor
(
const
std
::
vector
<
X
>&
lens
,
const
std
::
vector
<
Y
>&
strides
)
:
mLens
(
lens
),
mStrides
(
strides
)
:
mLens
(
lens
.
begin
(),
lens
.
end
()),
mStrides
(
strides
.
begin
(),
strides
.
end
()
)
{
}
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment